diff --git a/src/SingleStoreConnector/ColumnReaders/ColumnReader.cs b/src/SingleStoreConnector/ColumnReaders/ColumnReader.cs index e0bca6a64..e3de60cce 100644 --- a/src/SingleStoreConnector/ColumnReaders/ColumnReader.cs +++ b/src/SingleStoreConnector/ColumnReaders/ColumnReader.cs @@ -8,6 +8,30 @@ internal abstract class ColumnReader { public static ColumnReader Create(bool isBinary, ColumnDefinitionPayload columnDefinition, SingleStoreConnection connection) { + switch (columnDefinition.ExtendedTypeCode) + { + case SingleStoreExtendedTypeCode.Bson: + return BytesColumnReader.Instance; + + case SingleStoreExtendedTypeCode.Vector: + return columnDefinition.VectorElementType switch + { + SingleStoreVectorElementType.F32 => VectorFloat32ColumnReader.Instance, + SingleStoreVectorElementType.F64 => VectorFloat64ColumnReader.Instance, + SingleStoreVectorElementType.I8 => VectorInt8ColumnReader.Instance, + SingleStoreVectorElementType.I16 => VectorInt16ColumnReader.Instance, + SingleStoreVectorElementType.I32 => VectorInt32ColumnReader.Instance, + SingleStoreVectorElementType.I64 => VectorInt64ColumnReader.Instance, + null => throw new FormatException("VECTOR column is missing VectorElementType metadata."), + _ => throw new NotSupportedException( + $"Unsupported VECTOR element type: {columnDefinition.VectorElementType}."), + }; + + case SingleStoreExtendedTypeCode.None: + default: + break; + } + var isUnsigned = (columnDefinition.ColumnFlags & ColumnFlags.Unsigned) != 0; switch (columnDefinition.ColumnType) { diff --git a/src/SingleStoreConnector/ColumnReaders/VectorColumnReader.cs b/src/SingleStoreConnector/ColumnReaders/VectorColumnReader.cs new file mode 100644 index 000000000..769f8fc49 --- /dev/null +++ b/src/SingleStoreConnector/ColumnReaders/VectorColumnReader.cs @@ -0,0 +1,153 @@ +using System.Buffers.Binary; +using System.Runtime.InteropServices; +using SingleStoreConnector.Protocol.Payloads; + +namespace SingleStoreConnector.ColumnReaders; + +internal abstract class VectorColumnReaderBase : ColumnReader +{ + protected static void ValidateLength(ColumnDefinitionPayload columnDefinition, int dataLength, int elementSize, string elementTypeName) + { + if (dataLength % elementSize != 0) + { + throw new FormatException( + $"Expected VECTOR({elementTypeName}) payload length to be a multiple of {elementSize}, but got {dataLength}."); + } + + if (columnDefinition.VectorDimensions is { } dimensions) + { + var expectedLength = checked((ulong) dimensions * (ulong) elementSize); + if ((ulong) dataLength != expectedLength) + { + throw new FormatException( + $"Expected VECTOR({dimensions}, {elementTypeName}) payload length to be {expectedLength} bytes, but got {dataLength}."); + } + } + } +} + +internal sealed class VectorInt8ColumnReader : VectorColumnReaderBase +{ + public static VectorInt8ColumnReader Instance { get; } = new(); + + public override object ReadValue(ReadOnlySpan data, ColumnDefinitionPayload columnDefinition) + { + ValidateLength(columnDefinition, data.Length, sizeof(sbyte), "I8"); + return new ReadOnlyMemory(MemoryMarshal.Cast(data).ToArray()); + } +} + +internal sealed class VectorInt16ColumnReader : VectorColumnReaderBase +{ + public static VectorInt16ColumnReader Instance { get; } = new(); + + public override object ReadValue(ReadOnlySpan data, ColumnDefinitionPayload columnDefinition) + { + ValidateLength(columnDefinition, data.Length, sizeof(short), "I16"); + + if (BitConverter.IsLittleEndian) + return new ReadOnlyMemory(MemoryMarshal.Cast(data).ToArray()); + + var values = new short[data.Length / sizeof(short)]; + for (var i = 0; i < values.Length; i++) + values[i] = BinaryPrimitives.ReadInt16LittleEndian(data.Slice(i * sizeof(short), sizeof(short))); + + return new ReadOnlyMemory(values); + } +} + +internal sealed class VectorInt32ColumnReader : VectorColumnReaderBase +{ + public static VectorInt32ColumnReader Instance { get; } = new(); + + public override object ReadValue(ReadOnlySpan data, ColumnDefinitionPayload columnDefinition) + { + ValidateLength(columnDefinition, data.Length, sizeof(int), "I32"); + + if (BitConverter.IsLittleEndian) + return new ReadOnlyMemory(MemoryMarshal.Cast(data).ToArray()); + + var values = new int[data.Length / sizeof(int)]; + for (var i = 0; i < values.Length; i++) + values[i] = BinaryPrimitives.ReadInt32LittleEndian(data.Slice(i * sizeof(int), sizeof(int))); + + return new ReadOnlyMemory(values); + } +} + +internal sealed class VectorInt64ColumnReader : VectorColumnReaderBase +{ + public static VectorInt64ColumnReader Instance { get; } = new(); + + public override object ReadValue(ReadOnlySpan data, ColumnDefinitionPayload columnDefinition) + { + ValidateLength(columnDefinition, data.Length, sizeof(long), "I64"); + + if (BitConverter.IsLittleEndian) + return new ReadOnlyMemory(MemoryMarshal.Cast(data).ToArray()); + + var values = new long[data.Length / sizeof(long)]; + for (var i = 0; i < values.Length; i++) + values[i] = BinaryPrimitives.ReadInt64LittleEndian(data.Slice(i * sizeof(long), sizeof(long))); + + return new ReadOnlyMemory(values); + } +} + +internal sealed class VectorFloat32ColumnReader : VectorColumnReaderBase +{ + public static VectorFloat32ColumnReader Instance { get; } = new(); + + public override object ReadValue(ReadOnlySpan data, ColumnDefinitionPayload columnDefinition) + { + ValidateLength(columnDefinition, data.Length, sizeof(float), "F32"); + + if (BitConverter.IsLittleEndian) + return new ReadOnlyMemory(MemoryMarshal.Cast(data).ToArray()); + + var values = new float[data.Length / sizeof(float)]; + +#if NET5_0_OR_GREATER + for (var i = 0; i < values.Length; i++) + values[i] = BinaryPrimitives.ReadSingleLittleEndian(data.Slice(i * sizeof(float), sizeof(float))); +#else + var bytes = data.ToArray(); + for (var i = 0; i < values.Length; i++) + { + Array.Reverse(bytes, i * sizeof(float), sizeof(float)); + values[i] = BitConverter.ToSingle(bytes, i * sizeof(float)); + } +#endif + + return new ReadOnlyMemory(values); + } +} + +internal sealed class VectorFloat64ColumnReader : VectorColumnReaderBase +{ + public static VectorFloat64ColumnReader Instance { get; } = new(); + + public override object ReadValue(ReadOnlySpan data, ColumnDefinitionPayload columnDefinition) + { + ValidateLength(columnDefinition, data.Length, sizeof(double), "F64"); + + if (BitConverter.IsLittleEndian) + return new ReadOnlyMemory(MemoryMarshal.Cast(data).ToArray()); + + var values = new double[data.Length / sizeof(double)]; + +#if NET5_0_OR_GREATER + for (var i = 0; i < values.Length; i++) + values[i] = BinaryPrimitives.ReadDoubleLittleEndian(data.Slice(i * sizeof(double), sizeof(double))); +#else + var bytes = data.ToArray(); + for (var i = 0; i < values.Length; i++) + { + Array.Reverse(bytes, i * sizeof(double), sizeof(double)); + values[i] = BitConverter.ToDouble(bytes, i * sizeof(double)); + } +#endif + + return new ReadOnlyMemory(values); + } +} diff --git a/src/SingleStoreConnector/Core/ConnectionSettings.cs b/src/SingleStoreConnector/Core/ConnectionSettings.cs index c6841c8f8..9fd08103a 100644 --- a/src/SingleStoreConnector/Core/ConnectionSettings.cs +++ b/src/SingleStoreConnector/Core/ConnectionSettings.cs @@ -149,6 +149,8 @@ public ConnectionSettings(SingleStoreConnectionStringBuilder csb) UseAffectedRows = csb.UseAffectedRows; UseCompression = csb.UseCompression; UseXaTransactions = false; + EnableExtendedDataTypes = csb.EnableExtendedDataTypes; + EnableExtendedDataTypesWasExplicitlySet = csb.ContainsKey("Enable Extended Data Types"); static int ToSigned(uint value) => value >= int.MaxValue ? int.MaxValue : (int) value; } @@ -248,6 +250,8 @@ private static SingleStoreGuidFormat GetEffectiveGuidFormat(SingleStoreGuidForma public bool UseAffectedRows { get; } public bool UseCompression { get; } public bool UseXaTransactions { get; } + public bool EnableExtendedDataTypes { get; } + internal bool EnableExtendedDataTypesWasExplicitlySet { get; } public string ConnAttrsExtra { get; set; } public byte[]? ConnectionAttributes { get; set; } @@ -341,6 +345,8 @@ private ConnectionSettings(ConnectionSettings other, string host, int port, stri UseAffectedRows = other.UseAffectedRows; UseCompression = other.UseCompression; UseXaTransactions = other.UseXaTransactions; + EnableExtendedDataTypes = other.EnableExtendedDataTypes; + EnableExtendedDataTypesWasExplicitlySet = other.EnableExtendedDataTypesWasExplicitlySet; } private static readonly string[] s_localhostPipeServer = ["."]; diff --git a/src/SingleStoreConnector/Core/Row.cs b/src/SingleStoreConnector/Core/Row.cs index 4ea5ae3a9..23dbb46e1 100644 --- a/src/SingleStoreConnector/Core/Row.cs +++ b/src/SingleStoreConnector/Core/Row.cs @@ -2,6 +2,7 @@ using System.Text; using SingleStoreConnector.ColumnReaders; using SingleStoreConnector.Protocol; +using SingleStoreConnector.Protocol.Payloads; using SingleStoreConnector.Protocol.Serialization; #if !NETCOREAPP2_1_OR_GREATER && !NETSTANDARD2_1_OR_GREATER using SingleStoreConnector.Utilities; @@ -449,6 +450,16 @@ private void CheckBinaryColumn(int ordinal) throw new InvalidCastException("Column is NULL."); var column = ResultSet.ColumnDefinitions![ordinal]; + + switch (column.ExtendedTypeCode) + { + case SingleStoreExtendedTypeCode.Bson: + return; + + case SingleStoreExtendedTypeCode.Vector: + throw new InvalidCastException("Can't convert VECTOR to bytes."); + } + var columnType = column.ColumnType; if ((column.ColumnFlags & ColumnFlags.Binary) == 0 || (columnType != ColumnType.String && columnType != ColumnType.VarString && columnType != ColumnType.TinyBlob && diff --git a/src/SingleStoreConnector/Core/ServerSession.cs b/src/SingleStoreConnector/Core/ServerSession.cs index 4c8e224cf..635a591dc 100644 --- a/src/SingleStoreConnector/Core/ServerSession.cs +++ b/src/SingleStoreConnector/Core/ServerSession.cs @@ -802,6 +802,14 @@ public async Task TryResetConnectionAsync(ConnectionSettings cs, SingleSto payload = await ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false); OkPayload.Verify(payload.Span, this); + // re-enable extended types metadata if needed + if (cs.EnableExtendedDataTypes && S2ServerVersion.Version.CompareTo(S2Versions.SupportsExtendedDataTypes) >= 0) + { + await SendAsync(QueryPayload.Create(SupportsQueryAttributes, Encoding.ASCII.GetBytes("SET SESSION enable_extended_types_metadata = TRUE;")), ioBehavior, cancellationToken).ConfigureAwait(false); + payload = await ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false); + OkPayload.Verify(payload.Span, this); + } + return true; } catch (IOException ex) diff --git a/src/SingleStoreConnector/Core/ServerVersions.cs b/src/SingleStoreConnector/Core/ServerVersions.cs index 8897c8ba8..442fb1fa2 100644 --- a/src/SingleStoreConnector/Core/ServerVersions.cs +++ b/src/SingleStoreConnector/Core/ServerVersions.cs @@ -23,4 +23,5 @@ internal static class S2Versions public static readonly Version SupportsUtf8Mb4 = new(7, 5, 0); public static readonly Version SupportsResetConnection = new(7, 5, 0); public static readonly Version HasDataConversionCompatibilityLevelParameter = new(8, 0, 0); + public static readonly Version SupportsExtendedDataTypes = new(8, 5, 28); } diff --git a/src/SingleStoreConnector/Core/SingleStoreBinaryValueConverter.cs b/src/SingleStoreConnector/Core/SingleStoreBinaryValueConverter.cs new file mode 100644 index 000000000..be93cd380 --- /dev/null +++ b/src/SingleStoreConnector/Core/SingleStoreBinaryValueConverter.cs @@ -0,0 +1,162 @@ +using System.Buffers.Binary; +using System.Runtime.InteropServices; + +namespace SingleStoreConnector.Core; + +internal static class SingleStoreBinaryValueConverter +{ + public static bool TryInferSpecialSingleStoreDbType(object value, out SingleStoreDbType dbType) + { + // Use explicit type checks instead of pattern matching to avoid byte[]/sbyte[] confusion + var type = value.GetType(); + + // byte[] and related types should NOT infer as Vector - they use normal Blob type mapping + if (type == typeof(byte[]) || + type == typeof(ReadOnlyMemory) || + type == typeof(Memory) || + type == typeof(ArraySegment) || + value is MemoryStream) + { + dbType = default; + return false; + } + + // Numeric array types infer as Vector + if (type == typeof(float[]) || type == typeof(ReadOnlyMemory) || type == typeof(Memory) || + type == typeof(double[]) || type == typeof(ReadOnlyMemory) || type == typeof(Memory) || + type == typeof(sbyte[]) || type == typeof(ReadOnlyMemory) || type == typeof(Memory) || + type == typeof(short[]) || type == typeof(ReadOnlyMemory) || type == typeof(Memory) || + type == typeof(int[]) || type == typeof(ReadOnlyMemory) || type == typeof(Memory) || + type == typeof(long[]) || type == typeof(ReadOnlyMemory) || type == typeof(Memory)) + { + dbType = SingleStoreDbType.Vector; + return true; + } + + dbType = default; + return false; + } + + public static ReadOnlySpan GetBsonBytes(object value) => + GetRawBytes(value, SingleStoreDbType.Bson); + + public static ReadOnlySpan GetVectorBytes(object value) => + value switch + { + float[] x => ConvertFloatsToBytes(x.AsSpan()), + ReadOnlyMemory x => ConvertFloatsToBytes(x.Span), + Memory x => ConvertFloatsToBytes(x.Span), + + double[] x => ConvertDoublesToBytes(x.AsSpan()), + ReadOnlyMemory x => ConvertDoublesToBytes(x.Span), + Memory x => ConvertDoublesToBytes(x.Span), + + sbyte[] x => MemoryMarshal.AsBytes(x.AsSpan()), + ReadOnlyMemory x => MemoryMarshal.AsBytes(x.Span), + Memory x => MemoryMarshal.AsBytes(x.Span), + + short[] x => ConvertInt16ToBytes(x.AsSpan()), + ReadOnlyMemory x => ConvertInt16ToBytes(x.Span), + Memory x => ConvertInt16ToBytes(x.Span), + + int[] x => ConvertInt32ToBytes(x.AsSpan()), + ReadOnlyMemory x => ConvertInt32ToBytes(x.Span), + Memory x => ConvertInt32ToBytes(x.Span), + + long[] x => ConvertInt64ToBytes(x.AsSpan()), + ReadOnlyMemory x => ConvertInt64ToBytes(x.Span), + Memory x => ConvertInt64ToBytes(x.Span), + + byte[] or ReadOnlyMemory or Memory or ArraySegment or MemoryStream + => GetRawBytes(value, SingleStoreDbType.Vector), + + _ => throw new NotSupportedException( + $"Parameter type {value.GetType().Name} is not supported for SingleStoreDbType.Vector."), + }; + + public static ReadOnlySpan ConvertFloatsToBytes(ReadOnlySpan values) + { + if (BitConverter.IsLittleEndian) + { + return MemoryMarshal.AsBytes(values); + } + else + { + // for big-endian platforms, we need to convert each float individually + var bytes = new byte[values.Length * 4]; + + for (var i = 0; i < values.Length; i++) + { +#if NET5_0_OR_GREATER + BinaryPrimitives.WriteSingleLittleEndian(bytes.AsSpan(i * 4), values[i]); +#else + var floatBytes = BitConverter.GetBytes(values[i]); + Array.Reverse(floatBytes); + floatBytes.CopyTo(bytes, i * 4); +#endif + } + + return bytes; + } + } + + private static ReadOnlySpan ConvertDoublesToBytes(ReadOnlySpan values) + { + if (BitConverter.IsLittleEndian) + return MemoryMarshal.AsBytes(values); + + var bytes = new byte[values.Length * sizeof(double)]; + for (var i = 0; i < values.Length; i++) + { + var valueBytes = BitConverter.GetBytes(values[i]); + Array.Reverse(valueBytes); + valueBytes.CopyTo(bytes, i * sizeof(double)); + } + return bytes; + } + + private static ReadOnlySpan ConvertInt16ToBytes(ReadOnlySpan values) + { + if (BitConverter.IsLittleEndian) + return MemoryMarshal.AsBytes(values); + + var bytes = new byte[values.Length * sizeof(short)]; + for (var i = 0; i < values.Length; i++) + BinaryPrimitives.WriteInt16LittleEndian(bytes.AsSpan(i * sizeof(short), sizeof(short)), values[i]); + return bytes; + } + + private static ReadOnlySpan ConvertInt32ToBytes(ReadOnlySpan values) + { + if (BitConverter.IsLittleEndian) + return MemoryMarshal.AsBytes(values); + + var bytes = new byte[values.Length * sizeof(int)]; + for (var i = 0; i < values.Length; i++) + BinaryPrimitives.WriteInt32LittleEndian(bytes.AsSpan(i * sizeof(int), sizeof(int)), values[i]); + return bytes; + } + + private static ReadOnlySpan ConvertInt64ToBytes(ReadOnlySpan values) + { + if (BitConverter.IsLittleEndian) + return MemoryMarshal.AsBytes(values); + + var bytes = new byte[values.Length * sizeof(long)]; + for (var i = 0; i < values.Length; i++) + BinaryPrimitives.WriteInt64LittleEndian(bytes.AsSpan(i * sizeof(long), sizeof(long)), values[i]); + return bytes; + } + + private static ReadOnlySpan GetRawBytes(object value, SingleStoreDbType dbType) => + value switch + { + byte[] x => x, + ReadOnlyMemory x => x.Span, + Memory x => x.Span, + ArraySegment x => x.AsSpan(), + MemoryStream x => x.TryGetBuffer(out var buffer) ? buffer.AsSpan() : x.ToArray(), + _ => throw new NotSupportedException( + $"Parameter type {value.GetType().Name} is not supported for {dbType}."), + }; +} diff --git a/src/SingleStoreConnector/Core/TypeMapper.cs b/src/SingleStoreConnector/Core/TypeMapper.cs index 27abe20d2..31e27eea8 100644 --- a/src/SingleStoreConnector/Core/TypeMapper.cs +++ b/src/SingleStoreConnector/Core/TypeMapper.cs @@ -78,6 +78,12 @@ private TypeMapper() AddColumnTypeMetadata(new("MEDIUMBLOB", typeBinary, SingleStoreDbType.MediumBlob, binary: true, columnSize: 16777215, simpleDataTypeName: "BLOB")); AddColumnTypeMetadata(new("LONGBLOB", typeBinary, SingleStoreDbType.LongBlob, binary: true, columnSize: uint.MaxValue, simpleDataTypeName: "BLOB")); + // bson + AddColumnTypeMetadata(new("BSON", typeBinary, SingleStoreDbType.Bson, binary: true, columnSize: uint.MaxValue, simpleDataTypeName: "BSON", createFormat: "BSON")); + + // VECTOR: provider/logical type, blob transport + AddColumnTypeMetadata(new("VECTOR", typeBinary, SingleStoreDbType.Vector, binary: true, simpleDataTypeName: "VECTOR")); + // spatial AddColumnTypeMetadata(new("GEOGRAPHY", typeString, SingleStoreDbType.Geography, columnSize: 1073741823)); AddColumnTypeMetadata(new("POINT", typeString, SingleStoreDbType.GeographyPoint, columnSize: 48)); @@ -127,6 +133,11 @@ private TypeMapper() public SingleStoreDbType GetSingleStoreDbTypeForDbType(DbType dbType) { + // DbType.Binary is ambiguous because Blob, Binary, VarBinary, Bson, and Vector + // all use binary transport. We'll stick to preserving the historical/default inference. + if (dbType == DbType.Binary) + return SingleStoreDbType.Blob; + foreach (var pair in m_mySqlDbTypeToColumnTypeMetadata) { if (pair.Value.DbTypeMapping.DbTypes.Contains(dbType)) @@ -198,6 +209,14 @@ public SingleStoreDbType GetSingleStoreDbType(string typeName, bool unsigned, in public static SingleStoreDbType ConvertToSingleStoreDbType(ColumnDefinitionPayload columnDefinition, bool treatTinyAsBoolean, bool treatChar48AsGeographyPoint, SingleStoreGuidFormat guidFormat) { + switch (columnDefinition.ExtendedTypeCode) + { + case SingleStoreExtendedTypeCode.Bson: + return SingleStoreDbType.Bson; + case SingleStoreExtendedTypeCode.Vector: + return SingleStoreDbType.Vector; + } + var isUnsigned = (columnDefinition.ColumnFlags & ColumnFlags.Unsigned) != 0; if ((columnDefinition.ColumnFlags & ColumnFlags.Enum) != 0) return SingleStoreDbType.Enum; @@ -340,6 +359,11 @@ public static ushort ConvertToColumnTypeAndFlags(SingleStoreDbType dbType, Singl SingleStoreDbType.Blob or SingleStoreDbType.Text => ColumnType.Blob, SingleStoreDbType.MediumBlob or SingleStoreDbType.MediumText => ColumnType.MediumBlob, SingleStoreDbType.LongBlob or SingleStoreDbType.LongText => ColumnType.LongBlob, + + // NEW: transport BSON and VECTOR as BLOB + SingleStoreDbType.Bson => ColumnType.Blob, + SingleStoreDbType.Vector => ColumnType.Blob, + SingleStoreDbType.JSON => ColumnType.Json, // TODO: test SingleStoreDbType.Date or SingleStoreDbType.Newdate => ColumnType.Date, SingleStoreDbType.DateTime => ColumnType.DateTime, diff --git a/src/SingleStoreConnector/Protocol/ColumnType.cs b/src/SingleStoreConnector/Protocol/ColumnType.cs index 0c033c822..f9b9f681c 100644 --- a/src/SingleStoreConnector/Protocol/ColumnType.cs +++ b/src/SingleStoreConnector/Protocol/ColumnType.cs @@ -1,8 +1,8 @@ namespace SingleStoreConnector.Protocol; -/// -/// See SingleStore documentation. -/// +/// Base column type values from the MySQL-compatible protocol. +/// SingleStore-specific types such as BSON and VECTOR are exposed through extended metadata, +/// not as additional ColumnType values. internal enum ColumnType { Decimal = 0, diff --git a/src/SingleStoreConnector/Protocol/Payloads/ColumnDefinitionPayload.cs b/src/SingleStoreConnector/Protocol/Payloads/ColumnDefinitionPayload.cs index 9e7f7a91c..4160a843d 100644 --- a/src/SingleStoreConnector/Protocol/Payloads/ColumnDefinitionPayload.cs +++ b/src/SingleStoreConnector/Protocol/Payloads/ColumnDefinitionPayload.cs @@ -4,6 +4,23 @@ namespace SingleStoreConnector.Protocol.Payloads; +internal enum SingleStoreExtendedTypeCode : byte +{ + None = 0, + Bson = 1, + Vector = 2, +} + +internal enum SingleStoreVectorElementType : byte +{ + F32 = 1, + F64 = 2, + I8 = 3, + I16 = 4, + I32 = 5, + I64 = 6, +} + internal sealed class ColumnDefinitionPayload { public string Name @@ -76,6 +93,14 @@ public string PhysicalName public byte Decimals { get; private set; } + public int FixedLengthFieldsLength { get; private set; } + + public SingleStoreExtendedTypeCode ExtendedTypeCode { get; private set; } + + public uint? VectorDimensions { get; private set; } + + public SingleStoreVectorElementType? VectorElementType { get; private set; } + public static void Initialize(ref ColumnDefinitionPayload payload, ResizableArraySegment arraySegment) { payload ??= new ColumnDefinitionPayload(); @@ -91,7 +116,12 @@ private void Initialize(ResizableArraySegment originalData) SkipLengthEncodedByteString(ref reader); // physical table SkipLengthEncodedByteString(ref reader); // name SkipLengthEncodedByteString(ref reader); // physical name - reader.ReadByte(0x0C); // length of fixed-length fields, always 0x0C + + FixedLengthFieldsLength = checked((int) reader.ReadLengthEncodedInteger()); + if (FixedLengthFieldsLength < 12) + throw new FormatException( + $"Expected fixed-length fields length to be at least 12 bytes, but was {FixedLengthFieldsLength}."); + CharacterSet = (CharacterSet) reader.ReadUInt16(); ColumnLength = reader.ReadUInt32(); ColumnType = (ColumnType) reader.ReadByte(); @@ -100,6 +130,38 @@ private void Initialize(ResizableArraySegment originalData) reader.ReadByte(0); // reserved byte 1 reader.ReadByte(0); // reserved byte 2 + ExtendedTypeCode = SingleStoreExtendedTypeCode.None; + VectorDimensions = null; + VectorElementType = null; + + var remainingExtendedBytes = FixedLengthFieldsLength - 12; + if (remainingExtendedBytes > 0) + { + ExtendedTypeCode = (SingleStoreExtendedTypeCode) reader.ReadByte(); + remainingExtendedBytes--; + + switch (ExtendedTypeCode) + { + case SingleStoreExtendedTypeCode.None: + case SingleStoreExtendedTypeCode.Bson: + break; + case SingleStoreExtendedTypeCode.Vector: + if (remainingExtendedBytes < 5) + throw new FormatException( + $"Expected 5 additional bytes for VECTOR extended metadata, but only {remainingExtendedBytes} remained."); + + VectorDimensions = reader.ReadUInt32(); + VectorElementType = (SingleStoreVectorElementType) reader.ReadByte(); + remainingExtendedBytes -= 5; + break; + default: + break; + } + + if (remainingExtendedBytes > 0) + reader.Offset += remainingExtendedBytes; + } + if (m_readNames) { m_catalogName = null; diff --git a/src/SingleStoreConnector/SingleStoreBulkCopy.cs b/src/SingleStoreConnector/SingleStoreBulkCopy.cs index a7b079ca1..d7e177454 100644 --- a/src/SingleStoreConnector/SingleStoreBulkCopy.cs +++ b/src/SingleStoreConnector/SingleStoreBulkCopy.cs @@ -234,10 +234,40 @@ private async ValueTask WriteToServerAsync(IOBehavior for (var i = 0; i < schema.Count; i++) { var destinationColumn = reader.GetName(i); + var variableName = $"@`temporary_column_dotnet_connector_col{i}`"; + + if (schema[i] is not SingleStoreDbColumn singleStoreColumn) + { + // fallback to existing behavior + goto LegacyHandling; + } + + switch (singleStoreColumn.ProviderType) + { + case SingleStoreDbType.Vector: + { + if (singleStoreColumn.VectorDimensions is not { } dims || string.IsNullOrEmpty(singleStoreColumn.VectorElementTypeName)) + throw new InvalidOperationException( + $"VECTOR destination column '{destinationColumn}' is missing dimension or element type metadata."); + + var expression = $"%COL% = UNHEX(%VAR%):>VECTOR({dims}, {singleStoreColumn.VectorElementTypeName})"; + AddColumnMapping(m_logger, columnMappings, addDefaultMappings, i, destinationColumn, variableName, expression); + continue; + } + + case SingleStoreDbType.Bson: + { + var expression = "%COL% = UNHEX(%VAR%):>BSON"; + AddColumnMapping(m_logger, columnMappings, addDefaultMappings, i, destinationColumn, variableName, expression); + continue; + } + } + + LegacyHandling: var dataTypeName = schema[i].DataTypeName; if (dataTypeName == "BIT") { - AddColumnMapping(m_logger, columnMappings, addDefaultMappings, i, destinationColumn, $"@`temporary_column_dotnet_connector_col{i}`", $"%COL% = CAST(%VAR% AS UNSIGNED)"); + AddColumnMapping(m_logger, columnMappings, addDefaultMappings, i, destinationColumn, variableName, $"%COL% = CAST(%VAR% AS UNSIGNED)"); } else { @@ -245,7 +275,7 @@ private async ValueTask WriteToServerAsync(IOBehavior if (type == typeof(byte[]) || (type == typeof(Guid) && (m_connection.GuidFormat is SingleStoreGuidFormat.Binary16 or SingleStoreGuidFormat.LittleEndianBinary16 or SingleStoreGuidFormat.TimeSwapBinary16))) { - AddColumnMapping(m_logger, columnMappings, addDefaultMappings, i, destinationColumn, $"@`temporary_column_dotnet_connector_col{i}`", $"%COL% = UNHEX(%VAR%)"); + AddColumnMapping(m_logger, columnMappings, addDefaultMappings, i, destinationColumn, variableName, $"%COL% = UNHEX(%VAR%)"); } else if (addDefaultMappings) { @@ -473,17 +503,31 @@ static bool WriteValue(SingleStoreConnection connection, object value, ref int i { return Utf8Formatter.TryFormat(decimalValue, output, out bytesWritten); } - else if (value is byte[] or ReadOnlyMemory or Memory or ArraySegment or float[] or ReadOnlyMemory or Memory) + else if (value is byte[] or ReadOnlyMemory or Memory or ArraySegment or MemoryStream or + float[] or ReadOnlyMemory or Memory or + double[] or ReadOnlyMemory or Memory or + sbyte[] or ReadOnlyMemory or Memory or + short[] or ReadOnlyMemory or Memory or + int[] or ReadOnlyMemory or Memory or + long[] or ReadOnlyMemory or Memory) { var inputSpan = value switch { byte[] byteArray => byteArray.AsSpan(), ArraySegment arraySegment => arraySegment.AsSpan(), Memory memory => memory.Span, - float[] floatArray => SingleStoreParameter.ConvertFloatsToBytes(floatArray.AsSpan()), - Memory memory => SingleStoreParameter.ConvertFloatsToBytes(memory.Span), - ReadOnlyMemory memory => SingleStoreParameter.ConvertFloatsToBytes(memory.Span), - _ => ((ReadOnlyMemory) value).Span, + ReadOnlyMemory memory => memory.Span, + MemoryStream memoryStream => memoryStream.TryGetBuffer(out var streamBuffer) ? streamBuffer.AsSpan() : memoryStream.ToArray().AsSpan(), + + float[] or ReadOnlyMemory or Memory or + double[] or ReadOnlyMemory or Memory or + sbyte[] or ReadOnlyMemory or Memory or + short[] or ReadOnlyMemory or Memory or + int[] or ReadOnlyMemory or Memory or + long[] or ReadOnlyMemory or Memory + => SingleStoreBinaryValueConverter.GetVectorBytes(value), + + _ => throw new NotSupportedException($"Type {value.GetType().Name} not currently supported. Value: {value}") }; return WriteBytes(inputSpan, ref inputIndex, output, out bytesWritten); diff --git a/src/SingleStoreConnector/SingleStoreConnection.cs b/src/SingleStoreConnector/SingleStoreConnection.cs index fcb2e8e7a..da473d9bf 100644 --- a/src/SingleStoreConnector/SingleStoreConnection.cs +++ b/src/SingleStoreConnector/SingleStoreConnection.cs @@ -513,6 +513,31 @@ private async Task ChangeDatabaseAsync(IOBehavior ioBehavior, string databaseNam public new SingleStoreCommand CreateCommand() => (SingleStoreCommand) base.CreateCommand(); + private async Task InitializeSessionAsync(IOBehavior ioBehavior, CancellationToken cancellationToken) + { + var settings = GetInitializedConnectionSettings(); + + if (!settings.EnableExtendedDataTypes) + return; + + if (Session.S2ServerVersion.Version.CompareTo(S2Versions.SupportsExtendedDataTypes) < 0) + { + if (settings.EnableExtendedDataTypesWasExplicitlySet) + { + throw new NotSupportedException( + "EnableExtendedDataTypes requires SingleStore 8.5.28 or later."); + } + + return; + } + + await using var cmd = new SingleStoreCommand( + "SET SESSION enable_extended_types_metadata = TRUE;", + this); + + await cmd.ExecuteNonQueryAsync(ioBehavior, cancellationToken).ConfigureAwait(false); + } + #pragma warning disable CA2012 // Safe because method completes synchronously public bool Ping() => PingAsync(IOBehavior.Synchronous, CancellationToken.None).GetAwaiter().GetResult(); #pragma warning restore CA2012 @@ -572,6 +597,16 @@ internal async Task OpenAsync(IOBehavior? ioBehavior, CancellationToken cancella m_hasBeenOpened = true; SetState(ConnectionState.Open); + try + { + await InitializeSessionAsync(ioBehavior ?? AsyncIOBehavior, cancellationToken).ConfigureAwait(false); + } + catch + { + await CloseAsync(changeState: true, ioBehavior ?? AsyncIOBehavior).ConfigureAwait(false); + throw; + } + if (ConnectionOpenedCallback is { } autoEnlistConnectionOpenedCallback) { cancellationToken.ThrowIfCancellationRequested(); @@ -588,6 +623,16 @@ internal async Task OpenAsync(IOBehavior? ioBehavior, CancellationToken cancella cancellationToken).ConfigureAwait(false); m_hasBeenOpened = true; SetState(ConnectionState.Open); + + try + { + await InitializeSessionAsync(ioBehavior ?? AsyncIOBehavior, cancellationToken).ConfigureAwait(false); + } + catch + { + await CloseAsync(changeState: true, ioBehavior ?? AsyncIOBehavior).ConfigureAwait(false); + throw; + } } catch (OperationCanceledException ex) { @@ -1074,6 +1119,7 @@ internal void Cancel(ICancellableCommand command, int commandId, bool isCancel) internal bool IgnorePrepare => GetInitializedConnectionSettings().IgnorePrepare; internal bool NoBackslashEscapes => GetInitializedConnectionSettings().NoBackslashEscapes; internal bool TreatTinyAsBoolean => GetInitializedConnectionSettings().TreatTinyAsBoolean; + internal bool EnableExtendedDataTypes => GetInitializedConnectionSettings().EnableExtendedDataTypes; internal IOBehavior AsyncIOBehavior => GetConnectionSettings().ForceSynchronous ? IOBehavior.Synchronous diff --git a/src/SingleStoreConnector/SingleStoreConnectionStringBuilder.cs b/src/SingleStoreConnector/SingleStoreConnectionStringBuilder.cs index 7e2701f64..28d7f1a1e 100644 --- a/src/SingleStoreConnector/SingleStoreConnectionStringBuilder.cs +++ b/src/SingleStoreConnector/SingleStoreConnectionStringBuilder.cs @@ -827,6 +827,19 @@ public bool UseXaTransactions set => SingleStoreConnectionStringOption.UseXaTransactions.SetValue(this, value); } + /// + /// Enables extended data types, by enabling the enable_extended_types_metadata engine variable, that allows the connector to support extended data types, such as VECTOR and BSON + /// + [Category("Other")] + [DefaultValue(true)] + [Description("Enable extended data types engine variable for VECTOR and BSON support.")] + [DisplayName("Enable extended data types")] + public bool EnableExtendedDataTypes + { + get => SingleStoreConnectionStringOption.EnableExtendedDataTypes.GetValue(this); + set => SingleStoreConnectionStringOption.EnableExtendedDataTypes.SetValue(this, value); + } + // Other Methods /// @@ -987,6 +1000,7 @@ internal abstract partial class SingleStoreConnectionStringOption public static readonly SingleStoreConnectionStringValueOption UseAffectedRows; public static readonly SingleStoreConnectionStringValueOption UseCompression; public static readonly SingleStoreConnectionStringValueOption UseXaTransactions; + public static readonly SingleStoreConnectionStringValueOption EnableExtendedDataTypes; public static SingleStoreConnectionStringOption? TryGetOptionForKey(string key) => s_options.TryGetValue(key, out var option) ? option : null; @@ -1299,6 +1313,10 @@ static SingleStoreConnectionStringOption() AddOption(options, UseXaTransactions = new( keys: ["Use XA Transactions", "UseXaTransactions"], defaultValue: true)); + + AddOption(options, EnableExtendedDataTypes = new( + keys: ["Enable Extended Data Types", "EnableExtendedDataTypes"], + defaultValue: true)); #pragma warning restore SA1118 // Parameter should not span multiple lines #if NET8_0_OR_GREATER diff --git a/src/SingleStoreConnector/SingleStoreDbColumn.cs b/src/SingleStoreConnector/SingleStoreDbColumn.cs index 6f9b9ccf7..07b8d1fa0 100644 --- a/src/SingleStoreConnector/SingleStoreDbColumn.cs +++ b/src/SingleStoreConnector/SingleStoreDbColumn.cs @@ -13,11 +13,50 @@ internal SingleStoreDbColumn(int ordinal, ColumnDefinitionPayload column, bool a var columnTypeMetadata = TypeMapper.Instance.GetColumnTypeMetadata(mySqlDbType); var type = columnTypeMetadata.DbTypeMapping.ClrType; + var dataTypeName = columnTypeMetadata.SimpleDataTypeName; + VectorDimensions = null; + VectorElementTypeName = null; + + switch (mySqlDbType) + { + case SingleStoreDbType.Bson: + type = typeof(byte[]); + dataTypeName = "BSON"; + break; + + case SingleStoreDbType.Vector: + dataTypeName = "VECTOR"; + + VectorDimensions = column.VectorDimensions is { } dims + ? checked((int) dims) + : null; + + VectorElementTypeName = column.VectorElementType?.ToString(); + + type = column.VectorElementType switch + { + SingleStoreVectorElementType.F32 => typeof(ReadOnlyMemory), + SingleStoreVectorElementType.F64 => typeof(ReadOnlyMemory), + SingleStoreVectorElementType.I8 => typeof(ReadOnlyMemory), + SingleStoreVectorElementType.I16 => typeof(ReadOnlyMemory), + SingleStoreVectorElementType.I32 => typeof(ReadOnlyMemory), + SingleStoreVectorElementType.I64 => typeof(ReadOnlyMemory), + null => throw new FormatException("VECTOR column is missing VectorElementType metadata."), + _ => throw new NotSupportedException( + $"Unsupported VECTOR element type: {column.VectorElementType}."), + }; + break; + } + + if (mySqlDbType == SingleStoreDbType.Vector && VectorDimensions is { } vectorDimensions) + { + ColumnSize = vectorDimensions; + } // starting from 7.8 SingleStore returns number of characters (not amount of bytes) // for text types (e.g. Text, TinyText, MediumText, LongText) // (see https://grizzly.internal.memcompute.com/D54237) - if (serverVersion >= new Version(7, 8, 0) && + else if (serverVersion >= new Version(7, 8, 0) && mySqlDbType is SingleStoreDbType.LongText or SingleStoreDbType.MediumText or SingleStoreDbType.Text or SingleStoreDbType.TinyText) { // overflow may occur here for SingleStoreDbType.LongText @@ -46,16 +85,21 @@ internal SingleStoreDbColumn(int ordinal, ColumnDefinitionPayload column, bool a ColumnName = column.Name; ColumnOrdinal = ordinal; DataType = (allowZeroDateTime && type == typeof(DateTime)) ? typeof(SingleStoreDateTime) : type; - DataTypeName = columnTypeMetadata.SimpleDataTypeName; + DataTypeName = dataTypeName; if (mySqlDbType == SingleStoreDbType.String) DataTypeName += string.Format(CultureInfo.InvariantCulture, "({0})", ColumnSize); + else if (mySqlDbType == SingleStoreDbType.Vector && column is { VectorDimensions: { } dimensions, VectorElementType: { } elementType }) + { + DataTypeName += string.Format(CultureInfo.InvariantCulture, "({0}, {1})", dimensions, elementType); + } IsAliased = column.PhysicalName != column.Name; IsAutoIncrement = (column.ColumnFlags & ColumnFlags.AutoIncrement) != 0; IsExpression = false; IsHidden = false; IsKey = (column.ColumnFlags & ColumnFlags.PrimaryKey) != 0; - IsLong = column.ColumnLength > 255 && - ((column.ColumnFlags & ColumnFlags.Blob) != 0 || column.ColumnType is ColumnType.TinyBlob or ColumnType.Blob or ColumnType.MediumBlob or ColumnType.LongBlob); + IsLong = mySqlDbType != SingleStoreDbType.Vector && + column.ColumnLength > 255 && + ((column.ColumnFlags & ColumnFlags.Blob) != 0 || column.ColumnType is ColumnType.TinyBlob or ColumnType.Blob or ColumnType.MediumBlob or ColumnType.LongBlob); IsReadOnly = false; IsUnique = (column.ColumnFlags & ColumnFlags.UniqueKey) != 0; if (column.ColumnType is ColumnType.Decimal or ColumnType.NewDecimal) @@ -73,6 +117,10 @@ internal SingleStoreDbColumn(int ordinal, ColumnDefinitionPayload column, bool a public SingleStoreDbType ProviderType { get; } + public int? VectorDimensions { get; } + + public string? VectorElementTypeName { get; } + /// /// Gets the name of the table that the column belongs to. This will be the alias if the table is aliased in the query. /// diff --git a/src/SingleStoreConnector/SingleStoreDbType.cs b/src/SingleStoreConnector/SingleStoreDbType.cs index 0375eba9f..4214cb85a 100644 --- a/src/SingleStoreConnector/SingleStoreDbType.cs +++ b/src/SingleStoreConnector/SingleStoreDbType.cs @@ -50,4 +50,8 @@ public enum SingleStoreDbType LongText, Text, Guid = 800, + + // SingleStore logical/provider-only types backed by extended metadata. + Bson = 801, + Vector = 802, } diff --git a/src/SingleStoreConnector/SingleStoreParameter.cs b/src/SingleStoreConnector/SingleStoreParameter.cs index a32f1c483..ed3997d99 100644 --- a/src/SingleStoreConnector/SingleStoreParameter.cs +++ b/src/SingleStoreConnector/SingleStoreParameter.cs @@ -139,11 +139,19 @@ public override object? Value m_value = value; if (!HasSetDbType && value is not null) { - var typeMapping = TypeMapper.Instance.GetDbTypeMapping(value.GetType()); - if (typeMapping is not null) + if (SingleStoreBinaryValueConverter.TryInferSpecialSingleStoreDbType(value, out var specialDbType)) { - m_dbType = typeMapping.DbTypes[0]; - m_mySqlDbType = TypeMapper.Instance.GetSingleStoreDbTypeForDbType(m_dbType); + m_dbType = TypeMapper.Instance.GetDbTypeForSingleStoreDbType(specialDbType); + m_mySqlDbType = specialDbType; + } + else + { + var typeMapping = TypeMapper.Instance.GetDbTypeMapping(value.GetType()); + if (typeMapping is not null) + { + m_dbType = typeMapping.DbTypes[0]; + m_mySqlDbType = TypeMapper.Instance.GetSingleStoreDbTypeForDbType(m_dbType); + } } } } @@ -203,13 +211,20 @@ private SingleStoreParameter(SingleStoreParameter other, string parameterName) /// internal void AppendSqlString(ByteBufferWriter writer, StatementPreparerOptions options) { - const byte backslash = 0x5C, quote = 0x27, zeroByte = 0x00; var noBackslashEscapes = (options & StatementPreparerOptions.NoBackslashEscapes) == StatementPreparerOptions.NoBackslashEscapes; if (Value is null || Value == DBNull.Value) { writer.Write("NULL"u8); } + else if (SingleStoreDbType == SingleStoreDbType.Vector) + { + WriteBinaryLiteral(writer, noBackslashEscapes, SingleStoreBinaryValueConverter.GetVectorBytes(Value!)); + } + else if (SingleStoreDbType == SingleStoreDbType.Bson) + { + WriteBinaryLiteral(writer, noBackslashEscapes, SingleStoreBinaryValueConverter.GetBsonBytes(Value!)); + } else if (Value is string stringValue) { WriteString(writer, noBackslashEscapes, stringValue.AsSpan()); @@ -296,40 +311,13 @@ internal void AppendSqlString(ByteBufferWriter writer, StatementPreparerOptions ArraySegment arraySegment => arraySegment.AsSpan(), Memory memory => memory.Span, MemoryStream memoryStream => memoryStream.TryGetBuffer(out var streamBuffer) ? streamBuffer.AsSpan() : memoryStream.ToArray().AsSpan(), - float[] floatArray => ConvertFloatsToBytes(floatArray.AsSpan()), - Memory memory => ConvertFloatsToBytes(memory.Span), - ReadOnlyMemory memory => ConvertFloatsToBytes(memory.Span), + float[] floatArray => SingleStoreBinaryValueConverter.ConvertFloatsToBytes(floatArray.AsSpan()), + Memory memory => SingleStoreBinaryValueConverter.ConvertFloatsToBytes(memory.Span), + ReadOnlyMemory memory => SingleStoreBinaryValueConverter.ConvertFloatsToBytes(memory.Span), _ => ((ReadOnlyMemory) Value).Span, }; - // determine the number of bytes to be written - var length = inputSpan.Length + BinaryBytes.Length + 1; - foreach (var by in inputSpan) - { - if (by is quote or zeroByte || (by is backslash && !noBackslashEscapes)) - length++; - } - - var outputSpan = writer.GetSpan(length); - BinaryBytes.CopyTo(outputSpan); - var index = BinaryBytes.Length; - foreach (var by in inputSpan) - { - if (by is zeroByte) - { - outputSpan[index++] = (byte) '\\'; - outputSpan[index++] = (byte) '0'; - } - else - { - if (by is quote || by is backslash && !noBackslashEscapes) - outputSpan[index++] = by; - outputSpan[index++] = by; - } - } - outputSpan[index++] = quote; - Debug.Assert(index == length, "index == length"); - writer.Advance(index); + WriteBinaryLiteral(writer, noBackslashEscapes, inputSpan); } else if (Value is SingleStoreGeography or SingleStoreGeographyPoint) { @@ -668,6 +656,22 @@ internal void AppendBinary(ByteBufferWriter writer, StatementPreparerOptions opt private void AppendBinary(ByteBufferWriter writer, object value, StatementPreparerOptions options) { + if (SingleStoreDbType == SingleStoreDbType.Vector) + { + var bytes = SingleStoreBinaryValueConverter.GetVectorBytes(value); + writer.WriteLengthEncodedInteger(unchecked((ulong) bytes.Length)); + writer.Write(bytes); + return; + } + + if (SingleStoreDbType == SingleStoreDbType.Bson) + { + var bytes = SingleStoreBinaryValueConverter.GetBsonBytes(value); + writer.WriteLengthEncodedInteger(unchecked((ulong) bytes.Length)); + writer.Write(bytes); + return; + } + if (value is string stringValue) { writer.WriteLengthEncodedString(stringValue); @@ -790,17 +794,17 @@ private void AppendBinary(ByteBufferWriter writer, object value, StatementPrepar else if (value is float[] floatArrayValue) { writer.WriteLengthEncodedInteger(unchecked((ulong) floatArrayValue.Length * 4)); - writer.Write(ConvertFloatsToBytes(floatArrayValue.AsSpan())); + writer.Write(SingleStoreBinaryValueConverter.ConvertFloatsToBytes(floatArrayValue.AsSpan())); } else if (value is Memory floatMemory) { writer.WriteLengthEncodedInteger(unchecked((ulong) floatMemory.Length * 4)); - writer.Write(ConvertFloatsToBytes(floatMemory.Span)); + writer.Write(SingleStoreBinaryValueConverter.ConvertFloatsToBytes(floatMemory.Span)); } else if (value is ReadOnlyMemory floatReadOnlyMemory) { writer.WriteLengthEncodedInteger(unchecked((ulong) floatReadOnlyMemory.Length * 4)); - writer.Write(ConvertFloatsToBytes(floatReadOnlyMemory.Span)); + writer.Write(SingleStoreBinaryValueConverter.ConvertFloatsToBytes(floatReadOnlyMemory.Span)); } else if (value is decimal decimalValue) { @@ -970,6 +974,42 @@ private static void WriteDateOnly(ByteBufferWriter writer, DateOnly dateOnly) } #endif + private static void WriteBinaryLiteral(ByteBufferWriter writer, bool noBackslashEscapes, ReadOnlySpan inputSpan) + { + const byte backslash = 0x5C, quote = 0x27, zeroByte = 0x00; + + var length = inputSpan.Length + BinaryBytes.Length + 1; + foreach (var by in inputSpan) + { + if (by is quote or zeroByte || (by is backslash && !noBackslashEscapes)) + length++; + } + + var outputSpan = writer.GetSpan(length); + BinaryBytes.CopyTo(outputSpan); + var index = BinaryBytes.Length; + + foreach (var by in inputSpan) + { + if (by is zeroByte) + { + outputSpan[index++] = (byte) '\\'; + outputSpan[index++] = (byte) '0'; + } + else + { + if (by is quote || (by is backslash && !noBackslashEscapes)) + outputSpan[index++] = by; + + outputSpan[index++] = by; + } + } + + outputSpan[index++] = quote; + Debug.Assert(index == length, "index == length"); + writer.Advance(index); + } + private static void WriteDateTime(ByteBufferWriter writer, DateTime dateTime) { byte length; @@ -1019,32 +1059,6 @@ private static void WriteTime(ByteBufferWriter writer, TimeSpan timeSpan) } } - internal static ReadOnlySpan ConvertFloatsToBytes(ReadOnlySpan floats) - { - if (BitConverter.IsLittleEndian) - { - return MemoryMarshal.AsBytes(floats); - } - else - { - // for big-endian platforms, we need to convert each float individually - var bytes = new byte[floats.Length * 4]; - - for (var i = 0; i < floats.Length; i++) - { -#if NET5_0_OR_GREATER - BinaryPrimitives.WriteSingleLittleEndian(bytes.AsSpan(i * 4), floats[i]); -#else - var floatBytes = BitConverter.GetBytes(floats[i]); - Array.Reverse(floatBytes); - floatBytes.CopyTo(bytes, i * 4); -#endif - } - - return bytes; - } - } - private static ReadOnlySpan BinaryBytes => "_binary'"u8; private DbType m_dbType; diff --git a/tests/SideBySide/ParameterTests.cs b/tests/SideBySide/ParameterTests.cs index 3579d463c..20f967207 100644 --- a/tests/SideBySide/ParameterTests.cs +++ b/tests/SideBySide/ParameterTests.cs @@ -1,3 +1,5 @@ +using SingleStoreConnector.Core; + namespace SideBySide; public class ParameterTests @@ -31,7 +33,7 @@ public void DbTypeToSingleStoreDbType(DbType dbType, SingleStoreDbType mySqlDbTy [InlineData(new[] { DbType.Date }, new[] { SingleStoreDbType.Date, SingleStoreDbType.Newdate })] #if !BASELINE [InlineData(new[] { DbType.Int32 }, new[] { SingleStoreDbType.Int32, SingleStoreDbType.Year })] - [InlineData(new[] { DbType.Binary }, new[] { SingleStoreDbType.Blob, SingleStoreDbType.Binary, SingleStoreDbType.TinyBlob, SingleStoreDbType.MediumBlob, SingleStoreDbType.LongBlob })] + [InlineData(new[] { DbType.Binary }, new[] { SingleStoreDbType.Blob, SingleStoreDbType.Binary, SingleStoreDbType.TinyBlob, SingleStoreDbType.MediumBlob, SingleStoreDbType.LongBlob, SingleStoreDbType.Bson, SingleStoreDbType.Vector })] [InlineData(new[] { DbType.String, DbType.AnsiString, DbType.Xml }, new[] { SingleStoreDbType.VarChar, SingleStoreDbType.VarString, SingleStoreDbType.Text, SingleStoreDbType.TinyText, SingleStoreDbType.MediumText, SingleStoreDbType.LongText, SingleStoreDbType.JSON, SingleStoreDbType.Enum, SingleStoreDbType.Set, SingleStoreDbType.Geography, SingleStoreDbType.GeographyPoint })] [InlineData(new[] { DbType.Decimal, DbType.Currency }, new[] { SingleStoreDbType.NewDecimal, SingleStoreDbType.Decimal })] @@ -345,6 +347,87 @@ public void SetValueDoesNotInferType() Assert.Equal(SingleStoreDbType.Int32, parameter.SingleStoreDbType); } + [Theory] + [MemberData(nameof(VectorParameterValues))] + public void SetValueToNumericArrayInfersVector(object value) + { + var parameter = new SingleStoreParameter { Value = value }; + + Assert.Equal(DbType.Binary, parameter.DbType); + Assert.Equal(SingleStoreDbType.Vector, parameter.SingleStoreDbType); + } + + public static IEnumerable VectorParameterValues() + { + yield return [new float[] { 1, 2 }]; + yield return [new double[] { 1, 2 }]; + yield return [new short[] { 1, 2 }]; + yield return [new int[] { 1, 2 }]; + yield return [new long[] { 1, 2 }]; + } + + [Fact] + public void SetValueToSByteArrayInfersVector() + { + var parameter = new SingleStoreParameter { Value = new sbyte[] { 1, 2 } }; + + Assert.Equal(DbType.Binary, parameter.DbType); + Assert.Equal(SingleStoreDbType.Vector, parameter.SingleStoreDbType); + } + + [Fact] + public void SetValueToByteArrayInfersBlob() + { + var parameter = new SingleStoreParameter { Value = new byte[] { 1, 2, 3 } }; + + Assert.Equal(DbType.Binary, parameter.DbType); + Assert.Equal(SingleStoreDbType.Blob, parameter.SingleStoreDbType); + } + + [Fact] + public void ExplicitVectorCanUseByteArray() + { + var bytes = new byte[] { 1, 2, 3 }; + + var parameter = new SingleStoreParameter + { + SingleStoreDbType = SingleStoreDbType.Vector, + Value = bytes, + }; + + Assert.Equal(DbType.Binary, parameter.DbType); + Assert.Equal(SingleStoreDbType.Vector, parameter.SingleStoreDbType); + Assert.Same(bytes, parameter.Value); + } + + [Fact] + public void SetValueToNumericArrayDoesNotOverrideExplicitType() + { + var parameter = new SingleStoreParameter + { + SingleStoreDbType = SingleStoreDbType.Blob, + }; + + parameter.Value = new int[] { 1, 2, 3 }; + + Assert.Equal(DbType.Binary, parameter.DbType); + Assert.Equal(SingleStoreDbType.Blob, parameter.SingleStoreDbType); + } + + [Theory] + [InlineData(SingleStoreDbType.Vector)] + [InlineData(SingleStoreDbType.Bson)] + public void ExtendedSingleStoreDbTypesUseBinaryDbType(SingleStoreDbType singleStoreDbType) + { + var parameter = new SingleStoreParameter + { + SingleStoreDbType = singleStoreDbType, + }; + + Assert.Equal(DbType.Binary, parameter.DbType); + Assert.Equal(singleStoreDbType, parameter.SingleStoreDbType); + } + [Fact] public void ResetDbType() { diff --git a/tests/SingleStoreConnector.Tests/SingleStoreConnectionStringBuilderTests.cs b/tests/SingleStoreConnector.Tests/SingleStoreConnectionStringBuilderTests.cs index 276d61521..238b9f6ba 100644 --- a/tests/SingleStoreConnector.Tests/SingleStoreConnectionStringBuilderTests.cs +++ b/tests/SingleStoreConnector.Tests/SingleStoreConnectionStringBuilderTests.cs @@ -92,6 +92,7 @@ public void Defaults() Assert.False(csb.UseAffectedRows); #if !BASELINE Assert.True(csb.UseXaTransactions); + Assert.True(csb.EnableExtendedDataTypes); #endif } @@ -160,7 +161,8 @@ public void ParseConnectionString() "ssl mode=verifyca;" + "tls version=Tls12, TLS v1.3;" + "Uid=username;" + - "useaffectedrows=true", + "useaffectedrows=true;" + + "enableextendeddatatypes=true", }; Assert.True(csb.AllowLoadLocalInfile); Assert.True(csb.AllowPublicKeyRetrieval); @@ -231,6 +233,7 @@ public void ParseConnectionString() #endif Assert.True(csb.UseAffectedRows); Assert.Equal("username", csb.UserID); + Assert.True(csb.EnableExtendedDataTypes); #if !BASELINE Assert.Equal("Server=db-server;Port=1234;User ID=username;Password=Pass1234;Database=schema_name;Load Balance=Random;" + @@ -245,7 +248,7 @@ public void ParseConnectionString() "TreatChar48AsGeographyPoint=True;GUID Format=TimeSwapBinary16;Ignore Command Transaction=True;Ignore Prepare=True;Interactive Session=True;" + "Keep Alive=90;No Backslash Escapes=True;Old Guids=True;Persist Security Info=True;Pipelining=False;Server Redirection Mode=Required;" + "Server RSA Public Key File=rsa.pem;Server SPN=mariadb/host.example.com@EXAMPLE.COM;Treat Tiny As Boolean=False;" + - "Use Affected Rows=True;Use Compression=True;Use XA Transactions=False", + "Use Affected Rows=True;Use Compression=True;Use XA Transactions=False;Enable Extended Data Types=True", csb.ConnectionString.Replace("Protocol=NamedPipe", "Protocol=Pipe")); #endif }