diff --git a/fluss-common/src/main/java/org/apache/fluss/utils/InternalRowUtils.java b/fluss-common/src/main/java/org/apache/fluss/utils/InternalRowUtils.java index bfa2a299bc..b2e4a2cbc2 100644 --- a/fluss-common/src/main/java/org/apache/fluss/utils/InternalRowUtils.java +++ b/fluss-common/src/main/java/org/apache/fluss/utils/InternalRowUtils.java @@ -94,7 +94,13 @@ public static InternalArray copyArray(InternalArray from, DataType eleType) { return new GenericArray(newArray); } - private static InternalMap copyMap(InternalMap map, DataType keyType, DataType valueType) { + /** + * Creates a copy of the given {@link InternalMap}. + * + *
<p>
This method is intended for internal use by the Fluss Spark adapter and is not part of the + * stable public API. Its signature and behavior may change without notice across releases. + */ + public static InternalMap copyMap(InternalMap map, DataType keyType, DataType valueType) { if (map instanceof BinaryMap) { return ((BinaryMap) map).copy(); } diff --git a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/DataConverter.scala b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/DataConverter.scala index cb9b309003..7d27771ce2 100644 --- a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/DataConverter.scala +++ b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/DataConverter.scala @@ -76,8 +76,7 @@ object DataConverter { } def toSparkMap(flussMap: FlussInternalMap, mapType: FlussMapType): SparkMapData = { - // TODO: support map type in fluss-spark - throw new UnsupportedOperationException() + new FlussAsSparkMap(mapType).replace(flussMap) } def toSparkInternalRow(flussRow: FlussInternalRow, rowType: RowType): SparkInteralRow = { diff --git a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/FlussAsSparkArray.scala b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/FlussAsSparkArray.scala index c67ae1e7ad..8a00b8ec91 100644 --- a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/FlussAsSparkArray.scala +++ b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/FlussAsSparkArray.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayData => SparkArrayData, MapData import org.apache.spark.sql.types.{DataType => SparkDataType, Decimal => SparkDecimal} import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +/** Wraps a Fluss [[FlussInternalArray]] as a Spark [[SparkArrayData]]. */ class FlussAsSparkArray(elementType: FlussDataType) extends SparkArrayData { var flussArray: FlussInternalArray = _ diff --git a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/FlussAsSparkMap.scala b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/FlussAsSparkMap.scala new file mode 100644 index 0000000000..0c04583ae4 --- /dev/null +++ b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/FlussAsSparkMap.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.fluss.spark.row + +import org.apache.fluss.row.{InternalMap => FlussInternalMap} +import org.apache.fluss.types.{MapType => FlussMapType} +import org.apache.fluss.utils.InternalRowUtils + +import org.apache.spark.sql.catalyst.util.{ArrayData => SparkArrayData, MapData => SparkMapData} + +/** Wraps a Fluss [[FlussInternalMap]] as a Spark [[SparkMapData]]. */ +class FlussAsSparkMap(mapType: FlussMapType) extends SparkMapData { + + var flussMap: FlussInternalMap = _ + + def replace(map: FlussInternalMap): SparkMapData = { + this.flussMap = map + this + } + + override def numElements(): Int = flussMap.size() + + override def copy(): SparkMapData = { + new FlussAsSparkMap(mapType) + .replace(InternalRowUtils.copyMap(flussMap, mapType.getKeyType, mapType.getValueType)) + } + + override def keyArray(): SparkArrayData = { + val keyType = mapType.getKeyType + new FlussAsSparkArray(keyType).replace(flussMap.keyArray()) + } + + override def valueArray(): SparkArrayData = { + val valueType = mapType.getValueType + new FlussAsSparkArray(valueType).replace(flussMap.valueArray()) + } +} diff --git a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/FlussAsSparkRow.scala b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/FlussAsSparkRow.scala index 175900cd12..4744264704 100644 --- a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/FlussAsSparkRow.scala +++ b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/FlussAsSparkRow.scala @@ -18,7 +18,7 @@ package org.apache.fluss.spark.row import org.apache.fluss.row.{InternalRow => FlussInternalRow} -import org.apache.fluss.types.{ArrayType => FlussArrayType, BinaryType => FlussBinaryType, LocalZonedTimestampType, RowType, TimestampType} +import org.apache.fluss.types.{ArrayType => FlussArrayType, BinaryType => FlussBinaryType, LocalZonedTimestampType, MapType => FlussMapType, RowType, TimestampType} import org.apache.fluss.utils.InternalRowUtils import org.apache.spark.sql.catalyst.{InternalRow => SparkInteralRow} @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayData => SparkArrayData, MapData import org.apache.spark.sql.types.{DataType => SparkDataType, Decimal => SparkDecimal} import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +/** Wraps a Fluss [[FlussInternalRow]] as a Spark [[SparkInteralRow]]. 
*/ class FlussAsSparkRow(rowType: RowType) extends SparkInteralRow { val fieldCount: Int = rowType.getFieldCount @@ -104,8 +105,9 @@ class FlussAsSparkRow(rowType: RowType) extends SparkInteralRow { } override def getMap(ordinal: Int): SparkMapData = { - // TODO: support map type in fluss-spark - throw new UnsupportedOperationException() + val mapType = rowType.getTypeAt(ordinal).asInstanceOf[FlussMapType] + val flussMap = row.getMap(ordinal) + DataConverter.toSparkMap(flussMap, mapType) } override def get(ordinal: Int, dataType: SparkDataType): AnyRef = { diff --git a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/SparkAsFlussArray.scala b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/SparkAsFlussArray.scala index 17d45ec028..277f3a295c 100644 --- a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/SparkAsFlussArray.scala +++ b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/SparkAsFlussArray.scala @@ -20,7 +20,7 @@ package org.apache.fluss.spark.row import org.apache.fluss.row.{BinaryString, Decimal, InternalArray => FlussInternalArray, InternalMap, InternalRow => FlussInternalRow, TimestampLtz, TimestampNtz} import org.apache.spark.sql.catalyst.util.{ArrayData => SparkArrayData} -import org.apache.spark.sql.types.{ArrayType => SparkArrayType, DataType => SparkDataType, StructType} +import org.apache.spark.sql.types.{ArrayType => SparkArrayType, DataType => SparkDataType, MapType => SparkMapType, StructType} /** Wraps a Spark [[SparkArrayData]] as a Fluss [[FlussInternalArray]]. */ class SparkAsFlussArray(arrayData: SparkArrayData, elementType: SparkDataType) @@ -125,13 +125,14 @@ class SparkAsFlussArray(arrayData: SparkArrayData, elementType: SparkDataType) arrayData.getArray(pos), elementType.asInstanceOf[SparkArrayType].elementType) + /** Returns the map value at the given position. */ + override def getMap(pos: Int): InternalMap = { + val mapType = elementType.asInstanceOf[SparkMapType] + SparkAsFlussMap(arrayData.getMap(pos), mapType) + } + /** Returns the row value at the given position. */ override def getRow(pos: Int, numFields: Int): FlussInternalRow = new SparkAsFlussRow(elementType.asInstanceOf[StructType]) .replace(arrayData.getStruct(pos, numFields)) - - /** Returns the map value at the given position. */ - override def getMap(pos: Int): InternalMap = { - throw new UnsupportedOperationException() - } } diff --git a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/SparkAsFlussMap.scala b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/SparkAsFlussMap.scala new file mode 100644 index 0000000000..bd04300207 --- /dev/null +++ b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/SparkAsFlussMap.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.fluss.spark.row + +import org.apache.fluss.row.{InternalArray => FlussInternalArray, InternalMap => FlussInternalMap} + +import org.apache.spark.sql.catalyst.util.{MapData => SparkMapData} +import org.apache.spark.sql.types.{DataType => SparkDataType, MapType => SparkMapType} + +/** Wraps a Spark [[SparkMapData]] as a Fluss [[FlussInternalMap]]. */ +class SparkAsFlussMap(mapData: SparkMapData, keyType: SparkDataType, valueType: SparkDataType) + extends FlussInternalMap + with Serializable { + + /** Returns the number of key-value mappings in this map. */ + override def size(): Int = mapData.numElements() + + /** + * Returns an array view of the keys contained in this map. + * + *
<p>
A key-value pair has the same index in the key array and value array. + */ + override def keyArray(): FlussInternalArray = { + new SparkAsFlussArray(mapData.keyArray(), keyType) + } + + /** + * Returns an array view of the values contained in this map. + * + *
<p>
A key-value pair has the same index in the key array and value array. + */ + override def valueArray(): FlussInternalArray = { + new SparkAsFlussArray(mapData.valueArray(), valueType) + } +} + +object SparkAsFlussMap { + def apply(mapData: SparkMapData, mapType: SparkMapType): SparkAsFlussMap = + new SparkAsFlussMap(mapData, mapType.keyType, mapType.valueType) +} diff --git a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/SparkAsFlussRow.scala b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/SparkAsFlussRow.scala index 3a5c9613c2..368a7c973d 100644 --- a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/SparkAsFlussRow.scala +++ b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/SparkAsFlussRow.scala @@ -20,7 +20,7 @@ package org.apache.fluss.spark.row import org.apache.fluss.row.{BinaryString, Decimal, InternalMap, InternalRow => FlussInternalRow, TimestampLtz, TimestampNtz} import org.apache.spark.sql.catalyst.{InternalRow => SparkInternalRow} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{MapType => SparkMapType, StructType} /** Wraps a Spark [[SparkInternalRow]] as a Fluss [[FlussInternalRow]]. */ class SparkAsFlussRow(schema: StructType) extends FlussInternalRow with Serializable { @@ -127,6 +127,8 @@ class SparkAsFlussRow(schema: StructType) extends FlussInternalRow with Serializ /** Returns the map value at the given position. */ override def getMap(pos: Int): InternalMap = { - throw new UnsupportedOperationException() + val sparkMapData = row.getMap(pos) + val mapType = schema.fields(pos).dataType.asInstanceOf[SparkMapType] + SparkAsFlussMap(sparkMapData, mapType) } } diff --git a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/SparkLogTableReadTest.scala b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/SparkLogTableReadTest.scala index 5f9613cbc3..2342d658c2 100644 --- a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/SparkLogTableReadTest.scala +++ b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/SparkLogTableReadTest.scala @@ -213,18 +213,18 @@ class SparkLogTableReadTest extends FlussSparkTestBase { test("Spark Read: nested data types table") { withTable("t") { - // TODO: support map type sql(s""" |CREATE TABLE $DEFAULT_DATABASE.t ( |id INT, |arr ARRAY<INT>, + |map MAP<STRING, INT>, |struct_col STRUCT<f1: INT, f2: STRING> |)""".stripMargin) sql(s""" |INSERT INTO $DEFAULT_DATABASE.t VALUES - |(1, ARRAY(1, 2, 3), STRUCT(100, 'nested_value')), - |(2, ARRAY(7, 8, 9), STRUCT(200, 'nested_value2')) + |(1, ARRAY(1, 2, 3), MAP("k1", 111, "k2", 222), STRUCT(100, 'nested_value')), + |(2, ARRAY(7, 8, 9), MAP("k1", 333, "k2", 444), STRUCT(200, 'nested_value2')) |""".stripMargin) checkAnswer( @@ -232,10 +232,12 @@ class SparkLogTableReadTest extends FlussSparkTestBase { Row( 1, Seq(1, 2, 3), + Map("k1" -> 111, "k2" -> 222), Row(100, "nested_value") ) :: Row( 2, Seq(7, 8, 9), + Map("k1" -> 333, "k2" -> 444), Row(200, "nested_value2") ) :: Nil ) diff --git a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/SparkWriteTest.scala b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/SparkWriteTest.scala index 179c05e801..a2382c1616 100644 --- a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/SparkWriteTest.scala +++ b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/SparkWriteTest.scala @@ -47,7 +47,9 @@ class SparkWriteTest extends FlussSparkTestBase { | 1234567.89,
12345678900987654321.12, | "test", | TO_TIMESTAMP('2025-12-31 10:00:00', 'yyyy-MM-dd kk:mm:ss'), - | array(11.11F, 22.22F), struct(123L, "apache fluss") + | array(11.11F, 22.22F), + | map("k1", 111, "k2", 222), + | struct(123L, "apache fluss") |) |""".stripMargin) @@ -56,7 +58,7 @@ class SparkWriteTest extends FlussSparkTestBase { assertThat(rows.length).isEqualTo(1) val row = rows.head - assertThat(row.getFieldCount).isEqualTo(13) + assertThat(row.getFieldCount).isEqualTo(14) assertThat(row.getBoolean(0)).isEqualTo(true) assertThat(row.getByte(1)).isEqualTo(1.toByte) assertThat(row.getShort(2)).isEqualTo(10.toShort) @@ -71,7 +73,12 @@ class SparkWriteTest extends FlussSparkTestBase { assertThat(row.getTimestampLtz(10, 6).toInstant) .isEqualTo(Timestamp.valueOf("2025-12-31 10:00:00.0").toInstant) assertThat(row.getArray(11).toFloatArray).containsExactly(Array(11.11f, 22.22f): _*) - val nestedRow = row.getRow(12, 2) + val mapData = row.getMap(12) + assertThat(mapData.size()).isEqualTo(2) + assertThat(mapData.keyArray().getString(0).toString).isEqualTo("k1") + assertThat(mapData.keyArray().getString(1).toString).isEqualTo("k2") + assertThat(mapData.valueArray().toIntArray).containsExactly(Array(111, 222): _*) + val nestedRow = row.getRow(13, 2) assertThat(nestedRow.getFieldCount).isEqualTo(2) assertThat(nestedRow.getLong(0)).isEqualTo(123L) assertThat(nestedRow.getString(1).toString).isEqualTo("apache fluss") diff --git a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/DataConverterTest.scala b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/DataConverterTest.scala index 6866ba7b60..d7cfc9d2f9 100644 --- a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/DataConverterTest.scala +++ b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/DataConverterTest.scala @@ -17,14 +17,16 @@ package org.apache.fluss.spark.row -import org.apache.fluss.row.{BinaryString, Decimal => FlussDecimal, GenericArray, GenericRow, TimestampLtz, TimestampNtz} -import org.apache.fluss.types.{ArrayType, CharType, DataTypes, DecimalType, LocalZonedTimestampType, RowType, TimestampType} +import org.apache.fluss.row.{BinaryString, Decimal => FlussDecimal, GenericArray, GenericMap, GenericRow, TimestampLtz, TimestampNtz} +import org.apache.fluss.types.{ArrayType, CharType, DataTypes, DecimalType, LocalZonedTimestampType, MapType, RowType, TimestampType} import org.apache.spark.sql.types.{Decimal => SparkDecimal} import org.apache.spark.unsafe.types.UTF8String import org.assertj.core.api.Assertions.assertThat import org.scalatest.funsuite.AnyFunSuite +import scala.collection.JavaConverters._ + class DataConverterTest extends AnyFunSuite { test("toSparkObject: null value") { @@ -274,6 +276,28 @@ class DataConverterTest extends AnyFunSuite { assertThat(result.asInstanceOf[Long]).isEqualTo(2000000000L) // microseconds } + test("toSparkMap: Map type") { + val flussMap = new GenericMap( + Map( + BinaryString.fromString("a") -> Integer.valueOf(1), + BinaryString.fromString("b") -> Integer.valueOf(2)).asJava) + val mapType = new MapType(DataTypes.STRING, DataTypes.INT) + val sparkMap = DataConverter.toSparkMap(flussMap, mapType) + assertThat(sparkMap.numElements()).isEqualTo(2) + + val keyArray = sparkMap.keyArray() + val valueArray = sparkMap.valueArray() + assertThat(keyArray.numElements()).isEqualTo(2) + assertThat(valueArray.numElements()).isEqualTo(2) + + val actualMap = + (0 until sparkMap.numElements()) + .map(i => 
keyArray.getUTF8String(i).toString -> valueArray.getInt(i)) + .toMap + + assertThat(actualMap).isEqualTo(Map("a" -> 1, "b" -> 2)) + } + test("toSparkObject: ROW type") { val rowType = RowType .builder() @@ -288,10 +312,4 @@ class DataConverterTest extends AnyFunSuite { assertThat(result).isNotNull() assertThat(result.asInstanceOf[FlussAsSparkRow].getInt(0)).isEqualTo(42) } - - test("toSparkMap: unsupported") { - assertThrows[UnsupportedOperationException] { - DataConverter.toSparkMap(null, null) - } - } } diff --git a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/FlussAsSparkArrayTest.scala b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/FlussAsSparkArrayTest.scala index 029104e5bf..27a4bd529e 100644 --- a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/FlussAsSparkArrayTest.scala +++ b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/FlussAsSparkArrayTest.scala @@ -308,14 +308,17 @@ class FlussAsSparkArrayTest extends AnyFunSuite { assertThat(sparkInnerArray2.getInt(2)).isEqualTo(6) } - test("getMap: unsupported operation") { + test("getMap: read map array") { val mapType = DataTypes.MAP(DataTypes.INT, DataTypes.STRING) - val flussArray = GenericArray.of(new GenericMap(Map(1 -> "map").asJava)) + val innerMap = + new GenericMap(Map(Integer.valueOf(1) -> BinaryString.fromString("value1")).asJava) + val flussArray = new GenericArray(Array[Object](innerMap)) val sparkArray = new FlussAsSparkArray(mapType).replace(flussArray) - assertThrows[UnsupportedOperationException] { - sparkArray.getMap(0) - } + val sparkMap = sparkArray.getMap(0) + assertThat(sparkMap.numElements()).isEqualTo(1) + assertThat(sparkMap.keyArray().getInt(0)).isEqualTo(1) + assertThat(sparkMap.valueArray().getUTF8String(0).toString).isEqualTo("value1") } test("getInterval: unsupported operation") { diff --git a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/FlussAsSparkMapTest.scala b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/FlussAsSparkMapTest.scala new file mode 100644 index 0000000000..91d492dbda --- /dev/null +++ b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/FlussAsSparkMapTest.scala @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.fluss.spark.row + +import org.apache.fluss.row.{BinaryString, GenericArray, GenericMap, GenericRow} +import org.apache.fluss.spark.FlussSparkTestBase +import org.apache.fluss.types.{ArrayType, BigIntType, CharType, IntType, MapType, RowType, StringType} + +import org.assertj.core.api.Assertions.assertThat + +import scala.collection.JavaConverters._ + +class FlussAsSparkMapTest extends FlussSparkTestBase { + + test("numElements: empty map") { + val flussMap = new GenericMap(Map.empty[Object, Object].asJava) + val sparkMap = new FlussAsSparkMap(new MapType(new CharType(10), new IntType())) + .replace(flussMap) + + assertThat(sparkMap.numElements()).isEqualTo(0) + } + + test("numElements: non-empty map") { + val flussMap = createSimpleMap() + val sparkMap = new FlussAsSparkMap(new MapType(new CharType(10), new IntType())) + .replace(flussMap) + + assertThat(sparkMap.numElements()).isEqualTo(3) + } + + test("keyArray: empty map") { + val flussMap = new GenericMap(Map.empty[Object, Object].asJava) + val sparkMap = new FlussAsSparkMap(new MapType(new CharType(10), new IntType())) + .replace(flussMap) + + val sparkKeyArray = sparkMap.keyArray() + assertThat(sparkKeyArray.numElements()).isEqualTo(0) + } + + test("keyArray: non-empty map") { + val flussMap = createSimpleMap() + val sparkMap = new FlussAsSparkMap(new MapType(new CharType(10), new IntType())) + .replace(flussMap) + + val sparkKeyArray = sparkMap.keyArray() + assertThat(sparkKeyArray.numElements()).isEqualTo(3) + // Keys are in insertion order for GenericMap + val key0 = sparkKeyArray.getUTF8String(0).toString + val key1 = sparkKeyArray.getUTF8String(1).toString + val key2 = sparkKeyArray.getUTF8String(2).toString + assertThat(Set(key0, key1, key2)).isEqualTo(Set("key1", "key2", "key3")) + } + + test("valueArray: empty map") { + val flussMap = new GenericMap(Map.empty[Object, Object].asJava) + val sparkMap = new FlussAsSparkMap(new MapType(new CharType(10), new IntType())) + .replace(flussMap) + + val sparkValueArray = sparkMap.valueArray() + assertThat(sparkValueArray.numElements()).isEqualTo(0) + } + + test("valueArray: non-empty map") { + val flussMap = createSimpleMap() + val sparkMap = new FlussAsSparkMap(new MapType(new CharType(10), new IntType())) + .replace(flussMap) + + val sparkValueArray = sparkMap.valueArray() + assertThat(sparkValueArray.numElements()).isEqualTo(3) + assertThat(sparkValueArray.getInt(0)).isEqualTo(100) + assertThat(sparkValueArray.getInt(1)).isEqualTo(200) + assertThat(sparkValueArray.getInt(2)).isEqualTo(300) + } + + test("copy: creates deep copy") { + val flussMap = createSimpleMap() + val originalSparkMap = new FlussAsSparkMap(new MapType(new CharType(10), new IntType())) + .replace(flussMap) + val copiedSparkMap = originalSparkMap.copy() + + assertThat(copiedSparkMap.numElements()).isEqualTo(3) + } + + test("integration: map with nested array") { + val flussMap = createMapWithNestedArrays() + val sparkMap = new FlussAsSparkMap(new MapType(new CharType(10), new ArrayType(new IntType()))) + .replace(flussMap) + + assertThat(sparkMap.numElements()).isEqualTo(2) + + val sparkKeyArray = sparkMap.keyArray() + val key0 = sparkKeyArray.getUTF8String(0).toString + val key1 = sparkKeyArray.getUTF8String(1).toString + assertThat(Set(key0, key1)).isEqualTo(Set("arr1", "arr2")) + + val sparkValueArray = sparkMap.valueArray() + // Check that we have 2 arrays + assertThat(sparkValueArray.numElements()).isEqualTo(2) + } + + test("integration: map with nested row") { + val flussMap = 
createMapWithNestedRows() + val sparkMap = new FlussAsSparkMap(new MapType(new CharType(10), createSimpleRowType())) + .replace(flussMap) + + assertThat(sparkMap.numElements()).isEqualTo(2) + + val sparkKeyArray = sparkMap.keyArray() + val key0 = sparkKeyArray.getUTF8String(0).toString + val key1 = sparkKeyArray.getUTF8String(1).toString + assertThat(Set(key0, key1)).isEqualTo(Set("row1", "row2")) + } + + test("integration: map with nested map") { + val flussMap = createMapWithNestedMaps() + val sparkMap = new FlussAsSparkMap(new MapType(new CharType(10), createSimpleMapType())) + .replace(flussMap) + + assertThat(sparkMap.numElements()).isEqualTo(2) + + val sparkKeyArray = sparkMap.keyArray() + val key0 = sparkKeyArray.getUTF8String(0).toString + val key1 = sparkKeyArray.getUTF8String(1).toString + assertThat(Set(key0, key1)).isEqualTo(Set("map1", "map2")) + + val sparkValueArray = sparkMap.valueArray() + assertThat(sparkValueArray.numElements()).isEqualTo(2) + } + + private def createSimpleMap(): GenericMap = { + new GenericMap( + Map( + BinaryString.fromString("key1") -> Integer.valueOf(100), + BinaryString.fromString("key2") -> Integer.valueOf(200), + BinaryString.fromString("key3") -> Integer.valueOf(300) + ).asJava + ) + } + + private def createMapWithNestedArrays(): GenericMap = { + new GenericMap( + Map( + BinaryString.fromString("arr1") -> new GenericArray( + Array[Object](Integer.valueOf(1), Integer.valueOf(2), Integer.valueOf(3))), + BinaryString.fromString("arr2") -> new GenericArray( + Array[Object](Integer.valueOf(4), Integer.valueOf(5), Integer.valueOf(6))) + ).asJava + ) + } + + private def createMapWithNestedRows(): GenericMap = { + val row1 = new GenericRow(2) + row1.setField(0, java.lang.Long.valueOf(100L)) + row1.setField(1, BinaryString.fromString("value1")) + + val row2 = new GenericRow(2) + row2.setField(0, java.lang.Long.valueOf(200L)) + row2.setField(1, BinaryString.fromString("value2")) + + new GenericMap( + Map( + BinaryString.fromString("row1") -> row1, + BinaryString.fromString("row2") -> row2 + ).asJava + ) + } + + private def createMapWithNestedMaps(): GenericMap = { + val innerMap1 = new GenericMap( + Map( + BinaryString.fromString("inner1") -> Integer.valueOf(10), + BinaryString.fromString("inner2") -> Integer.valueOf(20) + ).asJava + ) + + val innerMap2 = new GenericMap( + Map( + BinaryString.fromString("inner3") -> Integer.valueOf(30), + BinaryString.fromString("inner4") -> Integer.valueOf(40) + ).asJava + ) + + new GenericMap( + Map( + BinaryString.fromString("map1") -> innerMap1, + BinaryString.fromString("map2") -> innerMap2 + ).asJava + ) + } + + private def createSimpleRowType(): RowType = { + val fields = java.util.Arrays.asList( + new org.apache.fluss.types.DataField("field1", new BigIntType()), + new org.apache.fluss.types.DataField("field2", new StringType()) + ) + new RowType(fields) + } + + private def createSimpleMapType(): MapType = { + new MapType(new CharType(10), new IntType()) + } +} diff --git a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/FlussAsSparkRowTest.scala b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/FlussAsSparkRowTest.scala index 03ccad4ef6..697fbd2f4d 100644 --- a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/FlussAsSparkRowTest.scala +++ b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/FlussAsSparkRowTest.scala @@ -17,12 +17,14 @@ package org.apache.fluss.spark.row -import org.apache.fluss.row.{BinaryString, Decimal => 
FlussDecimal, GenericArray, GenericRow, TimestampLtz, TimestampNtz} +import org.apache.fluss.row.{BinaryString, Decimal => FlussDecimal, GenericArray, GenericMap, GenericRow, TimestampLtz, TimestampNtz} import org.apache.fluss.types.{ArrayType, BinaryType, DataTypes, LocalZonedTimestampType, RowType, TimestampType} import org.assertj.core.api.Assertions.assertThat import org.scalatest.funsuite.AnyFunSuite +import scala.collection.JavaConverters._ + class FlussAsSparkRowTest extends AnyFunSuite { test("basic row operations: numFields and fieldCount") { @@ -339,20 +341,24 @@ class FlussAsSparkRowTest extends AnyFunSuite { assertThat(sparkArray.getInt(4)).isEqualTo(5) } - test("getMap: unsupported operation") { + test("getMap: read map field") { + val mapType = DataTypes.MAP(DataTypes.INT, DataTypes.STRING) val rowType = RowType .builder() - .field("dummy", DataTypes.INT) + .field("map_col", mapType) .build() + val flussMap = + new GenericMap(Map(Integer.valueOf(1) -> BinaryString.fromString("value1")).asJava) val flussRow = new GenericRow(1) - flussRow.setField(0, Integer.valueOf(1)) + flussRow.setField(0, flussMap) val sparkRow = new FlussAsSparkRow(rowType).replace(flussRow) - assertThrows[UnsupportedOperationException] { - sparkRow.getMap(0) - } + val sparkMap = sparkRow.getMap(0) + assertThat(sparkMap.numElements()).isEqualTo(1) + assertThat(sparkMap.keyArray().getInt(0)).isEqualTo(1) + assertThat(sparkMap.valueArray().getUTF8String(0).toString).isEqualTo("value1") } test("getInterval: unsupported operation") { diff --git a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/SparkAsFlussArrayTest.scala b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/SparkAsFlussArrayTest.scala index fb4ee6308a..74f07a58d2 100644 --- a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/SparkAsFlussArrayTest.scala +++ b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/SparkAsFlussArrayTest.scala @@ -20,7 +20,7 @@ package org.apache.fluss.spark.row import org.apache.fluss.spark.FlussSparkTestBase import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.GenericArrayData +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} import org.apache.spark.sql.types.{BooleanType, ByteType, Decimal, DecimalType, DoubleType, FloatType, IntegerType, LongType, ShortType, StringType, StructField, StructType} import org.apache.spark.unsafe.types.UTF8String import org.assertj.core.api.Assertions.assertThat @@ -129,4 +129,42 @@ class SparkAsFlussArrayTest extends FlussSparkTestBase { } } } + + test("Fluss SparkAsFlussArray: Map Type") { + val mapType = org.apache.spark.sql.types.MapType(StringType, IntegerType) + val data = Array( + ArrayBasedMapData.apply( + Array(UTF8String.fromString("key1"), UTF8String.fromString("key2")) + .asInstanceOf[Array[Any]], + Array(100, 200)), + ArrayBasedMapData.apply( + Array(UTF8String.fromString("key3"), UTF8String.fromString("key4")) + .asInstanceOf[Array[Any]], + Array(300, 400)), + null + ) + val sparkArrayData = new GenericArrayData(data) + val flussArray = new SparkAsFlussArray(sparkArrayData, mapType) + + assertThat(flussArray.size()).isEqualTo(3) + assertThat(flussArray.isNullAt(0)).isFalse() + assertThat(flussArray.isNullAt(1)).isFalse() + assertThat(flussArray.isNullAt(2)).isTrue() + + // Check first map + val map1 = flussArray.getMap(0) + assertThat(map1.size()).isEqualTo(2) + assertThat(map1.keyArray().getString(0).toString).isEqualTo("key1") 
+ assertThat(map1.keyArray().getString(1).toString).isEqualTo("key2") + assertThat(map1.valueArray().getInt(0)).isEqualTo(100) + assertThat(map1.valueArray().getInt(1)).isEqualTo(200) + + // Check second map + val map2 = flussArray.getMap(1) + assertThat(map2.size()).isEqualTo(2) + assertThat(map2.keyArray().getString(0).toString).isEqualTo("key3") + assertThat(map2.keyArray().getString(1).toString).isEqualTo("key4") + assertThat(map2.valueArray().getInt(0)).isEqualTo(300) + assertThat(map2.valueArray().getInt(1)).isEqualTo(400) + } } diff --git a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/SparkAsFlussMapTest.scala b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/SparkAsFlussMapTest.scala new file mode 100644 index 0000000000..383ccf0e54 --- /dev/null +++ b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/SparkAsFlussMapTest.scala @@ -0,0 +1,429 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.fluss.spark.row + +import org.apache.fluss.spark.FlussSparkTestBase + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData} +import org.apache.spark.sql.types.{IntegerType, MapType, StringType, StructType} +import org.apache.spark.unsafe.types.UTF8String +import org.assertj.core.api.Assertions.assertThat + +class SparkAsFlussMapTest extends FlussSparkTestBase { + + // Helper method to convert String array to UTF8String array + private def toUTF8Strings(strings: String*): Array[Any] = { + strings.map(UTF8String.fromString).toArray + } + + test("size: empty map") { + val sparkMap = new ArrayBasedMapData( + ArrayData.toArrayData(Array.empty[Any]), + ArrayData.toArrayData(Array.empty[Any])) + val mapType = MapType(StringType, IntegerType) + val flussMap = SparkAsFlussMap(sparkMap, mapType) + + assertThat(flussMap.size()).isEqualTo(0) + } + + test("size: non-empty map") { + val keys = toUTF8Strings("key1", "key2", "key3") + val values = Array(100, 200, 300) + val sparkMap = new ArrayBasedMapData(ArrayData.toArrayData(keys), ArrayData.toArrayData(values)) + val mapType = MapType(StringType, IntegerType) + val flussMap = SparkAsFlussMap(sparkMap, mapType) + + assertThat(flussMap.size()).isEqualTo(3) + } + + test("keyArray: empty map") { + val sparkMap = new ArrayBasedMapData( + ArrayData.toArrayData(Array.empty[Any]), + ArrayData.toArrayData(Array.empty[Any])) + val mapType = MapType(StringType, IntegerType) + val flussMap = SparkAsFlussMap(sparkMap, mapType) + + val keyArray = flussMap.keyArray() + assertThat(keyArray.size()).isEqualTo(0) + } + + test("keyArray: non-empty map") { + val keys = toUTF8Strings("key1", "key2", "key3") + val values = Array(100, 200, 300) + val sparkMap = new 
ArrayBasedMapData(ArrayData.toArrayData(keys), ArrayData.toArrayData(values)) + val mapType = MapType(StringType, IntegerType) + val flussMap = SparkAsFlussMap(sparkMap, mapType) + + val keyArray = flussMap.keyArray() + assertThat(keyArray.size()).isEqualTo(3) + assertThat(keyArray.getString(0).toString).isEqualTo("key1") + assertThat(keyArray.getString(1).toString).isEqualTo("key2") + assertThat(keyArray.getString(2).toString).isEqualTo("key3") + } + + test("valueArray: empty map") { + val sparkMap = new ArrayBasedMapData( + ArrayData.toArrayData(Array.empty[Any]), + ArrayData.toArrayData(Array.empty[Any])) + val mapType = MapType(StringType, IntegerType) + val flussMap = SparkAsFlussMap(sparkMap, mapType) + + val valueArray = flussMap.valueArray() + assertThat(valueArray.size()).isEqualTo(0) + } + + test("valueArray: non-empty map") { + val keys = toUTF8Strings("key1", "key2", "key3") + val values = Array(100, 200, 300) + val sparkMap = new ArrayBasedMapData(ArrayData.toArrayData(keys), ArrayData.toArrayData(values)) + val mapType = MapType(StringType, IntegerType) + val flussMap = SparkAsFlussMap(sparkMap, mapType) + + val valueArray = flussMap.valueArray() + assertThat(valueArray.size()).isEqualTo(3) + assertThat(valueArray.getInt(0)).isEqualTo(100) + assertThat(valueArray.getInt(1)).isEqualTo(200) + assertThat(valueArray.getInt(2)).isEqualTo(300) + } + + test("integration: map with nested array") { + val keys = toUTF8Strings("array1", "array2") + val values = Array( + ArrayData.toArrayData(Array(1, 2, 3)), + ArrayData.toArrayData(Array(4, 5, 6)) + ) + val sparkMap = new ArrayBasedMapData(ArrayData.toArrayData(keys), ArrayData.toArrayData(values)) + val mapType = MapType(StringType, org.apache.spark.sql.types.ArrayType(IntegerType)) + val flussMap = SparkAsFlussMap(sparkMap, mapType) + + assertThat(flussMap.size()).isEqualTo(2) + + val keyArray = flussMap.keyArray() + assertThat(keyArray.getString(0).toString).isEqualTo("array1") + assertThat(keyArray.getString(1).toString).isEqualTo("array2") + + val valueArray = flussMap.valueArray() + val array1 = valueArray.getArray(0) + assertThat(array1.size()).isEqualTo(3) + assertThat(array1.getInt(0)).isEqualTo(1) + assertThat(array1.getInt(1)).isEqualTo(2) + assertThat(array1.getInt(2)).isEqualTo(3) + + val array2 = valueArray.getArray(1) + assertThat(array2.size()).isEqualTo(3) + assertThat(array2.getInt(0)).isEqualTo(4) + assertThat(array2.getInt(1)).isEqualTo(5) + assertThat(array2.getInt(2)).isEqualTo(6) + } + + test("integration: map with nested row") { + val keys = toUTF8Strings("row1", "row2") + val values = Array( + InternalRow.apply(UTF8String.fromString("value1"), 100), + InternalRow.apply(UTF8String.fromString("value2"), 200) + ) + val sparkMap = new ArrayBasedMapData(ArrayData.toArrayData(keys), ArrayData.toArrayData(values)) + val structType = new StructType().add("name", StringType).add("value", IntegerType) + val mapType = MapType(StringType, structType) + val flussMap = SparkAsFlussMap(sparkMap, mapType) + + assertThat(flussMap.size()).isEqualTo(2) + + val keyArray = flussMap.keyArray() + assertThat(keyArray.getString(0).toString).isEqualTo("row1") + assertThat(keyArray.getString(1).toString).isEqualTo("row2") + + val valueArray = flussMap.valueArray() + val row1 = valueArray.getRow(0, 2) + assertThat(row1.getString(0).toString).isEqualTo("value1") + assertThat(row1.getInt(1)).isEqualTo(100) + + val row2 = valueArray.getRow(1, 2) + assertThat(row2.getString(0).toString).isEqualTo("value2") + 
assertThat(row2.getInt(1)).isEqualTo(200) + } + + test("integration: map with nested map") { + val keys = toUTF8Strings("map1", "map2") + val values = Array( + new ArrayBasedMapData( + ArrayData.toArrayData(toUTF8Strings("inner1", "inner2")), + ArrayData.toArrayData(Array(10, 20))), + new ArrayBasedMapData( + ArrayData.toArrayData(toUTF8Strings("inner3", "inner4")), + ArrayData.toArrayData(Array(30, 40))) + ) + val sparkMap = new ArrayBasedMapData(ArrayData.toArrayData(keys), ArrayData.toArrayData(values)) + val innerMapType = MapType(StringType, IntegerType) + val mapType = MapType(StringType, innerMapType) + val flussMap = SparkAsFlussMap(sparkMap, mapType) + + assertThat(flussMap.size()).isEqualTo(2) + + val keyArray = flussMap.keyArray() + assertThat(keyArray.getString(0).toString).isEqualTo("map1") + assertThat(keyArray.getString(1).toString).isEqualTo("map2") + + val valueArray = flussMap.valueArray() + val innerMap1 = valueArray.getMap(0) + assertThat(innerMap1.size()).isEqualTo(2) + val innerMap1Keys = innerMap1.keyArray() + assertThat(innerMap1Keys.getString(0).toString).isEqualTo("inner1") + assertThat(innerMap1Keys.getString(1).toString).isEqualTo("inner2") + val innerMap1Values = innerMap1.valueArray() + assertThat(innerMap1Values.getInt(0)).isEqualTo(10) + assertThat(innerMap1Values.getInt(1)).isEqualTo(20) + + val innerMap2 = valueArray.getMap(1) + assertThat(innerMap2.size()).isEqualTo(2) + val innerMap2Keys = innerMap2.keyArray() + assertThat(innerMap2Keys.getString(0).toString).isEqualTo("inner3") + assertThat(innerMap2Keys.getString(1).toString).isEqualTo("inner4") + val innerMap2Values = innerMap2.valueArray() + assertThat(innerMap2Values.getInt(0)).isEqualTo(30) + assertThat(innerMap2Values.getInt(1)).isEqualTo(40) + } + + test("basic accessors return expected keys and values") { + val keys = toUTF8Strings("key1", "key2") + val values = Array(100, 200) + val sparkMap = new ArrayBasedMapData(ArrayData.toArrayData(keys), ArrayData.toArrayData(values)) + val mapType = MapType(StringType, IntegerType) + val flussMap = SparkAsFlussMap(sparkMap, mapType) + + // Verify the wrapped map exposes the expected size, keys, and values. 
+ assertThat(flussMap.size()).isEqualTo(2) + assertThat(flussMap.keyArray().getString(0).toString).isEqualTo("key1") + assertThat(flussMap.valueArray().getInt(1)).isEqualTo(200) + } + + test("map with integer values") { + val keys = toUTF8Strings("int_key1", "int_key2") + val values = Array(100, 200) + val sparkMap = new ArrayBasedMapData(ArrayData.toArrayData(keys), ArrayData.toArrayData(values)) + val mapType = MapType(StringType, IntegerType) + val flussMap = SparkAsFlussMap(sparkMap, mapType) + + assertThat(flussMap.size()).isEqualTo(2) + val valueArray = flussMap.valueArray() + assertThat(valueArray.getInt(0)).isEqualTo(100) + assertThat(valueArray.getInt(1)).isEqualTo(200) + } + + test("map with float values") { + val keys = toUTF8Strings("float_key1", "float_key2") + val values = Array(12.34f, 56.78f) + val sparkMap = new ArrayBasedMapData(ArrayData.toArrayData(keys), ArrayData.toArrayData(values)) + val mapType = MapType(StringType, org.apache.spark.sql.types.FloatType) + val flussMap = SparkAsFlussMap(sparkMap, mapType) + + assertThat(flussMap.size()).isEqualTo(2) + val valueArray = flussMap.valueArray() + assertThat(valueArray.getFloat(0)).isEqualTo(12.34f) + assertThat(valueArray.getFloat(1)).isEqualTo(56.78f) + } + + test("map with double values") { + val keys = toUTF8Strings("double_key1", "double_key2") + val values = Array(56.78, 90.12) + val sparkMap = new ArrayBasedMapData(ArrayData.toArrayData(keys), ArrayData.toArrayData(values)) + val mapType = MapType(StringType, org.apache.spark.sql.types.DoubleType) + val flussMap = SparkAsFlussMap(sparkMap, mapType) + + assertThat(flussMap.size()).isEqualTo(2) + val valueArray = flussMap.valueArray() + assertThat(valueArray.getDouble(0)).isEqualTo(56.78) + assertThat(valueArray.getDouble(1)).isEqualTo(90.12) + } + + test("map with long values") { + val keys = toUTF8Strings("long_key1", "long_key2") + val values = Array(1000L, 2000L) + val sparkMap = new ArrayBasedMapData(ArrayData.toArrayData(keys), ArrayData.toArrayData(values)) + val mapType = MapType(StringType, org.apache.spark.sql.types.LongType) + val flussMap = SparkAsFlussMap(sparkMap, mapType) + + assertThat(flussMap.size()).isEqualTo(2) + val valueArray = flussMap.valueArray() + assertThat(valueArray.getLong(0)).isEqualTo(1000L) + assertThat(valueArray.getLong(1)).isEqualTo(2000L) + } + + test("map with boolean values") { + val keys = toUTF8Strings("bool_key1", "bool_key2") + val values = Array(true, false) + val sparkMap = new ArrayBasedMapData(ArrayData.toArrayData(keys), ArrayData.toArrayData(values)) + val mapType = MapType(StringType, org.apache.spark.sql.types.BooleanType) + val flussMap = SparkAsFlussMap(sparkMap, mapType) + + assertThat(flussMap.size()).isEqualTo(2) + val valueArray = flussMap.valueArray() + assertThat(valueArray.getBoolean(0)).isTrue() + assertThat(valueArray.getBoolean(1)).isFalse() + } + + test("map with byte values") { + val keys = toUTF8Strings("byte_key1", "byte_key2") + val values = Array(127.toByte, 64.toByte) + val sparkMap = new ArrayBasedMapData(ArrayData.toArrayData(keys), ArrayData.toArrayData(values)) + val mapType = MapType(StringType, org.apache.spark.sql.types.ByteType) + val flussMap = SparkAsFlussMap(sparkMap, mapType) + + assertThat(flussMap.size()).isEqualTo(2) + val valueArray = flussMap.valueArray() + assertThat(valueArray.getByte(0)).isEqualTo(127.toByte) + assertThat(valueArray.getByte(1)).isEqualTo(64.toByte) + } + + test("map with short values") { + val keys = toUTF8Strings("short_key1", "short_key2") + val values = 
Array(1000.toShort, 2000.toShort) + val sparkMap = new ArrayBasedMapData(ArrayData.toArrayData(keys), ArrayData.toArrayData(values)) + val mapType = MapType(StringType, org.apache.spark.sql.types.ShortType) + val flussMap = SparkAsFlussMap(sparkMap, mapType) + + assertThat(flussMap.size()).isEqualTo(2) + val valueArray = flussMap.valueArray() + assertThat(valueArray.getShort(0)).isEqualTo(1000.toShort) + assertThat(valueArray.getShort(1)).isEqualTo(2000.toShort) + } + + test("map with null values") { + val keys = toUTF8Strings("key1", "key2", "key3") + val values = Array[Any](UTF8String.fromString("value1"), null, UTF8String.fromString("value3")) + val sparkMap = new ArrayBasedMapData(ArrayData.toArrayData(keys), ArrayData.toArrayData(values)) + val mapType = MapType(StringType, StringType) + val flussMap = SparkAsFlussMap(sparkMap, mapType) + + assertThat(flussMap.size()).isEqualTo(3) + val valueArray = flussMap.valueArray() + assertThat(valueArray.isNullAt(0)).isFalse() + assertThat(valueArray.isNullAt(1)).isTrue() + assertThat(valueArray.isNullAt(2)).isFalse() + assertThat(valueArray.getString(0).toString).isEqualTo("value1") + assertThat(valueArray.getString(2).toString).isEqualTo("value3") + } + + test("map with numeric keys") { + val keys = Array(1, 2, 3) + val values = toUTF8Strings("value1", "value2", "value3") + val sparkMap = new ArrayBasedMapData(ArrayData.toArrayData(keys), ArrayData.toArrayData(values)) + val mapType = MapType(IntegerType, StringType) + val flussMap = SparkAsFlussMap(sparkMap, mapType) + + assertThat(flussMap.size()).isEqualTo(3) + val keyArray = flussMap.keyArray() + val valueArray = flussMap.valueArray() + assertThat(keyArray.getInt(0)).isEqualTo(1) + assertThat(keyArray.getInt(1)).isEqualTo(2) + assertThat(keyArray.getInt(2)).isEqualTo(3) + assertThat(valueArray.getString(0).toString).isEqualTo("value1") + assertThat(valueArray.getString(1).toString).isEqualTo("value2") + assertThat(valueArray.getString(2).toString).isEqualTo("value3") + } + + test("map with complex nested structure: array of rows") { + val keys = toUTF8Strings("data1", "data2") + val values = Array( + ArrayData.toArrayData( + Array( + InternalRow.apply(UTF8String.fromString("name1"), 100), + InternalRow.apply(UTF8String.fromString("name2"), 200) + )), + ArrayData.toArrayData( + Array( + InternalRow.apply(UTF8String.fromString("name3"), 300), + InternalRow.apply(UTF8String.fromString("name4"), 400) + )) + ) + val sparkMap = new ArrayBasedMapData(ArrayData.toArrayData(keys), ArrayData.toArrayData(values)) + val structType = new StructType().add("name", StringType).add("value", IntegerType) + val arrayType = org.apache.spark.sql.types.ArrayType(structType) + val mapType = MapType(StringType, arrayType) + val flussMap = SparkAsFlussMap(sparkMap, mapType) + + assertThat(flussMap.size()).isEqualTo(2) + val keyArray = flussMap.keyArray() + assertThat(keyArray.getString(0).toString).isEqualTo("data1") + assertThat(keyArray.getString(1).toString).isEqualTo("data2") + + val valueArray = flussMap.valueArray() + val array1 = valueArray.getArray(0) + assertThat(array1.size()).isEqualTo(2) + val row1_0 = array1.getRow(0, 2) + assertThat(row1_0.getString(0).toString).isEqualTo("name1") + assertThat(row1_0.getInt(1)).isEqualTo(100) + val row1_1 = array1.getRow(1, 2) + assertThat(row1_1.getString(0).toString).isEqualTo("name2") + assertThat(row1_1.getInt(1)).isEqualTo(200) + + val array2 = valueArray.getArray(1) + assertThat(array2.size()).isEqualTo(2) + val row2_0 = array2.getRow(0, 2) + 
assertThat(row2_0.getString(0).toString).isEqualTo("name3") + assertThat(row2_0.getInt(1)).isEqualTo(300) + val row2_1 = array2.getRow(1, 2) + assertThat(row2_1.getString(0).toString).isEqualTo("name4") + assertThat(row2_1.getInt(1)).isEqualTo(400) + } + + test("map with decimal values") { + val keys = toUTF8Strings("dec1", "dec2") + val values = Array( + org.apache.spark.sql.types.Decimal(123.45), + org.apache.spark.sql.types.Decimal(678.90) + ) + val sparkMap = new ArrayBasedMapData(ArrayData.toArrayData(keys), ArrayData.toArrayData(values)) + val mapType = MapType(StringType, org.apache.spark.sql.types.DecimalType(10, 2)) + val flussMap = SparkAsFlussMap(sparkMap, mapType) + + assertThat(flussMap.size()).isEqualTo(2) + val keyArray = flussMap.keyArray() + assertThat(keyArray.getString(0).toString).isEqualTo("dec1") + assertThat(keyArray.getString(1).toString).isEqualTo("dec2") + + val valueArray = flussMap.valueArray() + val dec1 = valueArray.getDecimal(0, 10, 2) + assertThat(dec1.toBigDecimal.compareTo(new java.math.BigDecimal("123.45"))).isZero() + val dec2 = valueArray.getDecimal(1, 10, 2) + assertThat(dec2.toBigDecimal.compareTo(new java.math.BigDecimal("678.90"))).isZero() + } + + test("map with timestamp values") { + val keys = toUTF8Strings("ts1", "ts2") + val values = Array( + 1634567890123456L, // microseconds timestamp + 1634567891123456L + ) + val sparkMap = new ArrayBasedMapData(ArrayData.toArrayData(keys), ArrayData.toArrayData(values)) + val mapType = MapType(StringType, org.apache.spark.sql.types.TimestampType) + val flussMap = SparkAsFlussMap(sparkMap, mapType) + + assertThat(flussMap.size()).isEqualTo(2) + val keyArray = flussMap.keyArray() + assertThat(keyArray.getString(0).toString).isEqualTo("ts1") + assertThat(keyArray.getString(1).toString).isEqualTo("ts2") + + val valueArray = flussMap.valueArray() + val ts1 = valueArray.getTimestampNtz(0, 6) + assertThat(ts1.toEpochMicros).isEqualTo(1634567890123456L) + val ts2 = valueArray.getTimestampNtz(1, 6) + assertThat(ts2.toEpochMicros).isEqualTo(1634567891123456L) + } +} diff --git a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/SparkAsFlussRowTest.scala b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/SparkAsFlussRowTest.scala index 6a5240427d..3cdae4c8b8 100644 --- a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/SparkAsFlussRowTest.scala +++ b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/SparkAsFlussRowTest.scala @@ -21,7 +21,7 @@ import org.apache.fluss.spark.FlussSparkTestBase import org.apache.fluss.spark.util.TestUtils.SCHEMA import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.GenericArrayData +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} import org.apache.spark.sql.types.Decimal import org.apache.spark.unsafe.types.UTF8String import org.assertj.core.api.Assertions.assertThat @@ -48,13 +48,16 @@ class SparkAsFlussRowTest extends FlussSparkTestBase { UTF8String.fromString("test"), Timestamp.valueOf("2025-12-31 10:00:00").getTime * 1000, new GenericArrayData(Array(11.11f, 22.22f)), + ArrayBasedMapData.apply( + Array[Any](UTF8String.fromString("k1"), UTF8String.fromString("k2")), + Array(111, 222)), InternalRow.apply(123L, UTF8String.fromString("apache fluss")) )) row = new SparkAsFlussRow(SCHEMA).replace(data) } test("Fluss SparkAsFlussRow") { - assertThat(row.fieldCount).isEqualTo(13) + assertThat(row.fieldCount).isEqualTo(14) 
assertThat(row.getBoolean(0)).isEqualTo(true) assertThat(row.getByte(1)).isEqualTo(1.toByte) @@ -73,8 +76,14 @@ class SparkAsFlussRowTest extends FlussSparkTestBase { // test array type assertThat(row.getArray(11).toFloatArray).containsExactly(Array(11.11f, 22.22f): _*) + // test map type + assertThat(row.getMap(12).size()).isEqualTo(2) + assertThat(row.getMap(12).keyArray().getString(0).toString).isEqualTo("k1") + assertThat(row.getMap(12).keyArray().getString(1).toString).isEqualTo("k2") + assertThat(row.getMap(12).valueArray().toIntArray).containsExactly(Array(111, 222): _*) + // test row type - val nestedRow = row.getRow(12, 2) + val nestedRow = row.getRow(13, 2) assertThat(nestedRow.getFieldCount).isEqualTo(2) assertThat(nestedRow.getLong(0)).isEqualTo(123L) assertThat(nestedRow.getString(1).toString).isEqualTo("apache fluss") diff --git a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/util/TestUtils.scala b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/util/TestUtils.scala index 704ca2ed50..bb06cd730b 100644 --- a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/util/TestUtils.scala +++ b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/util/TestUtils.scala @@ -40,6 +40,7 @@ object TestUtils { // StructField("date", DateType), StructField("c_timestamp", TimestampType), StructField("c_array", ArrayType(FloatType, containsNull = false)), + StructField("c_map", MapType(StringType, IntegerType, valueContainsNull = false)), StructField( "c_row", StructType(Seq(StructField("id", LongType), StructField("name", StringType))))
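For reference, a minimal sketch of the Fluss-to-Spark direction this patch adds: wrapping a Fluss GenericMap as Spark MapData. It reuses only classes and calls that appear in the diff above (GenericMap, BinaryString, MapType, FlussAsSparkMap.replace, the accessors asserted in DataConverterTest); the standalone setup around them is illustrative, not part of the change.

```scala
import org.apache.fluss.row.{BinaryString, GenericMap}
import org.apache.fluss.spark.row.FlussAsSparkMap
import org.apache.fluss.types.{DataTypes, MapType}

import scala.collection.JavaConverters._

// Build a Fluss map value, as DataConverterTest does.
val flussMap = new GenericMap(
  Map(
    BinaryString.fromString("a") -> Integer.valueOf(1),
    BinaryString.fromString("b") -> Integer.valueOf(2)).asJava)
val mapType = new MapType(DataTypes.STRING, DataTypes.INT)

// FlussAsSparkMap is a mutable view: replace() swaps the backing map without copying.
// DataConverter.toSparkMap(flussMap, mapType) performs exactly this wrapping.
val sparkMap = new FlussAsSparkMap(mapType).replace(flussMap)

// Keys and values stay index-aligned, so entries can be reassembled pairwise.
val entries = (0 until sparkMap.numElements())
  .map(i => sparkMap.keyArray().getUTF8String(i).toString -> sparkMap.valueArray().getInt(i))
  .toMap
assert(entries == Map("a" -> 1, "b" -> 2))
```

Because the wrapper is only a view over the underlying Fluss data, copy() must materialize a deep copy, which is why the patch widens InternalRowUtils.copyMap from private to public (see the javadoc added at the top of the diff).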
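And the reverse direction, mirroring SparkAsFlussMapTest: wrapping Spark catalyst MapData as a Fluss InternalMap. Again a sketch using only APIs exercised by the tests above; ArrayBasedMapData holds keys and values as two parallel arrays.

```scala
import org.apache.fluss.spark.row.SparkAsFlussMap

import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData}
import org.apache.spark.sql.types.{IntegerType, MapType, StringType}
import org.apache.spark.unsafe.types.UTF8String

// Spark stores map entries as two parallel, index-aligned arrays.
val keys = Array[Any](UTF8String.fromString("k1"), UTF8String.fromString("k2"))
val values = Array(111, 222)
val sparkMap =
  new ArrayBasedMapData(ArrayData.toArrayData(keys), ArrayData.toArrayData(values))

// The companion apply splits the Spark MapType into the key/value element types
// needed by the SparkAsFlussArray views returned from keyArray()/valueArray().
val flussMap = SparkAsFlussMap(sparkMap, MapType(StringType, IntegerType))

assert(flussMap.size() == 2)
assert(flussMap.keyArray().getString(0).toString == "k1")
assert(flussMap.valueArray().getInt(1) == 222)
```

Note that both wrappers are lazy in the same way: nested arrays, rows, and maps are only re-wrapped when an element is accessed, which keeps per-row conversion allocation-light.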