diff --git a/fluss-common/src/main/java/org/apache/fluss/utils/InternalRowUtils.java b/fluss-common/src/main/java/org/apache/fluss/utils/InternalRowUtils.java
index bfa2a299bc..b2e4a2cbc2 100644
--- a/fluss-common/src/main/java/org/apache/fluss/utils/InternalRowUtils.java
+++ b/fluss-common/src/main/java/org/apache/fluss/utils/InternalRowUtils.java
@@ -94,7 +94,13 @@ public static InternalArray copyArray(InternalArray from, DataType eleType) {
return new GenericArray(newArray);
}
- private static InternalMap copyMap(InternalMap map, DataType keyType, DataType valueType) {
+ /**
+ * Creates a copy of the given {@link InternalMap}.
+ *
+ * <p>This method is intended for internal use by the Fluss Spark adapter and is not part of the
+ * stable public API. Its signature and behavior may change without notice across releases.
+ */
+ public static InternalMap copyMap(InternalMap map, DataType keyType, DataType valueType) {
if (map instanceof BinaryMap) {
return ((BinaryMap) map).copy();
}
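
A minimal sketch (not part of the patch; values are hypothetical) of invoking the now-public copyMap from the Scala adapter:

    import org.apache.fluss.row.{BinaryString, GenericMap, InternalMap}
    import org.apache.fluss.types.DataTypes
    import org.apache.fluss.utils.InternalRowUtils

    import scala.collection.JavaConverters._

    // A one-entry map used purely for illustration.
    val original: InternalMap = new GenericMap(
      Map(BinaryString.fromString("k") -> Integer.valueOf(1)).asJava)
    // Deep-copies the entries so the result is detached from the source map.
    val copied: InternalMap =
      InternalRowUtils.copyMap(original, DataTypes.STRING, DataTypes.INT)
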
diff --git a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/DataConverter.scala b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/DataConverter.scala
index cb9b309003..7d27771ce2 100644
--- a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/DataConverter.scala
+++ b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/DataConverter.scala
@@ -76,8 +76,7 @@ object DataConverter {
}
def toSparkMap(flussMap: FlussInternalMap, mapType: FlussMapType): SparkMapData = {
- // TODO: support map type in fluss-spark
- throw new UnsupportedOperationException()
+ new FlussAsSparkMap(mapType).replace(flussMap)
}
def toSparkInternalRow(flussRow: FlussInternalRow, rowType: RowType): SparkInteralRow = {
diff --git a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/FlussAsSparkArray.scala b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/FlussAsSparkArray.scala
index c67ae1e7ad..8a00b8ec91 100644
--- a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/FlussAsSparkArray.scala
+++ b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/FlussAsSparkArray.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayData => SparkArrayData, MapData
import org.apache.spark.sql.types.{DataType => SparkDataType, Decimal => SparkDecimal}
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+/** Wraps a Fluss [[FlussInternalArray]] as a Spark [[SparkArrayData]]. */
class FlussAsSparkArray(elementType: FlussDataType) extends SparkArrayData {
var flussArray: FlussInternalArray = _
diff --git a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/FlussAsSparkMap.scala b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/FlussAsSparkMap.scala
new file mode 100644
index 0000000000..0c04583ae4
--- /dev/null
+++ b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/FlussAsSparkMap.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.fluss.spark.row
+
+import org.apache.fluss.row.{InternalMap => FlussInternalMap}
+import org.apache.fluss.types.{MapType => FlussMapType}
+import org.apache.fluss.utils.InternalRowUtils
+
+import org.apache.spark.sql.catalyst.util.{ArrayData => SparkArrayData, MapData => SparkMapData}
+
+/** Wraps a Fluss [[FlussInternalMap]] as a Spark [[SparkMapData]]. */
+class FlussAsSparkMap(mapType: FlussMapType) extends SparkMapData {
+
+ var flussMap: FlussInternalMap = _
+
+ def replace(map: FlussInternalMap): SparkMapData = {
+ this.flussMap = map
+ this
+ }
+
+ override def numElements(): Int = flussMap.size()
+
+ override def copy(): SparkMapData = {
+ new FlussAsSparkMap(mapType)
+ .replace(InternalRowUtils.copyMap(flussMap, mapType.getKeyType, mapType.getValueType))
+ }
+
+ override def keyArray(): SparkArrayData = {
+ val keyType = mapType.getKeyType
+ new FlussAsSparkArray(keyType).replace(flussMap.keyArray())
+ }
+
+ override def valueArray(): SparkArrayData = {
+ val valueType = mapType.getValueType
+ new FlussAsSparkArray(valueType).replace(flussMap.valueArray())
+ }
+}
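
A hedged usage sketch for the new wrapper (not part of the patch; values are hypothetical): wrap a Fluss GenericMap and read it through Spark's MapData API:

    import org.apache.fluss.row.{BinaryString, GenericMap}
    import org.apache.fluss.types.{DataTypes, MapType}

    import scala.collection.JavaConverters._

    val flussMap = new GenericMap(
      Map(BinaryString.fromString("k1") -> Integer.valueOf(1)).asJava)
    val sparkMap = new FlussAsSparkMap(new MapType(DataTypes.STRING, DataTypes.INT))
      .replace(flussMap)
    // Keys and values are exposed as two parallel Spark arrays.
    sparkMap.numElements()                        // 1
    sparkMap.keyArray().getUTF8String(0).toString // "k1"
    sparkMap.valueArray().getInt(0)               // 1
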
diff --git a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/FlussAsSparkRow.scala b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/FlussAsSparkRow.scala
index 175900cd12..4744264704 100644
--- a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/FlussAsSparkRow.scala
+++ b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/FlussAsSparkRow.scala
@@ -18,7 +18,7 @@
package org.apache.fluss.spark.row
import org.apache.fluss.row.{InternalRow => FlussInternalRow}
-import org.apache.fluss.types.{ArrayType => FlussArrayType, BinaryType => FlussBinaryType, LocalZonedTimestampType, RowType, TimestampType}
+import org.apache.fluss.types.{ArrayType => FlussArrayType, BinaryType => FlussBinaryType, LocalZonedTimestampType, MapType => FlussMapType, RowType, TimestampType}
import org.apache.fluss.utils.InternalRowUtils
import org.apache.spark.sql.catalyst.{InternalRow => SparkInteralRow}
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayData => SparkArrayData, MapData
import org.apache.spark.sql.types.{DataType => SparkDataType, Decimal => SparkDecimal}
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+/** Wraps a Fluss [[FlussInternalRow]] as a Spark [[SparkInteralRow]]. */
class FlussAsSparkRow(rowType: RowType) extends SparkInteralRow {
val fieldCount: Int = rowType.getFieldCount
@@ -104,8 +105,9 @@ class FlussAsSparkRow(rowType: RowType) extends SparkInteralRow {
}
override def getMap(ordinal: Int): SparkMapData = {
- // TODO: support map type in fluss-spark
- throw new UnsupportedOperationException()
+ val mapType = rowType.getTypeAt(ordinal).asInstanceOf[FlussMapType]
+ val flussMap = row.getMap(ordinal)
+ DataConverter.toSparkMap(flussMap, mapType)
}
override def get(ordinal: Int, dataType: SparkDataType): AnyRef = {
diff --git a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/SparkAsFlussArray.scala b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/SparkAsFlussArray.scala
index 17d45ec028..277f3a295c 100644
--- a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/SparkAsFlussArray.scala
+++ b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/SparkAsFlussArray.scala
@@ -20,7 +20,7 @@ package org.apache.fluss.spark.row
import org.apache.fluss.row.{BinaryString, Decimal, InternalArray => FlussInternalArray, InternalMap, InternalRow => FlussInternalRow, TimestampLtz, TimestampNtz}
import org.apache.spark.sql.catalyst.util.{ArrayData => SparkArrayData}
-import org.apache.spark.sql.types.{ArrayType => SparkArrayType, DataType => SparkDataType, StructType}
+import org.apache.spark.sql.types.{ArrayType => SparkArrayType, DataType => SparkDataType, MapType => SparkMapType, StructType}
/** Wraps a Spark [[SparkArrayData]] as a Fluss [[FlussInternalArray]]. */
class SparkAsFlussArray(arrayData: SparkArrayData, elementType: SparkDataType)
@@ -125,13 +125,14 @@ class SparkAsFlussArray(arrayData: SparkArrayData, elementType: SparkDataType)
arrayData.getArray(pos),
elementType.asInstanceOf[SparkArrayType].elementType)
+ /** Returns the map value at the given position. */
+ override def getMap(pos: Int): InternalMap = {
+ val mapType = elementType.asInstanceOf[SparkMapType]
+ SparkAsFlussMap(arrayData.getMap(pos), mapType)
+ }
+
/** Returns the row value at the given position. */
override def getRow(pos: Int, numFields: Int): FlussInternalRow =
new SparkAsFlussRow(elementType.asInstanceOf[StructType])
.replace(arrayData.getStruct(pos, numFields))
-
- /** Returns the map value at the given position. */
- override def getMap(pos: Int): InternalMap = {
- throw new UnsupportedOperationException()
- }
}
diff --git a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/SparkAsFlussMap.scala b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/SparkAsFlussMap.scala
new file mode 100644
index 0000000000..bd04300207
--- /dev/null
+++ b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/SparkAsFlussMap.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.fluss.spark.row
+
+import org.apache.fluss.row.{InternalArray => FlussInternalArray, InternalMap => FlussInternalMap}
+
+import org.apache.spark.sql.catalyst.util.{MapData => SparkMapData}
+import org.apache.spark.sql.types.{DataType => SparkDataType, MapType => SparkMapType}
+
+/** Wraps a Spark [[SparkMapData]] as a Fluss [[FlussInternalMap]]. */
+class SparkAsFlussMap(mapData: SparkMapData, keyType: SparkDataType, valueType: SparkDataType)
+ extends FlussInternalMap
+ with Serializable {
+
+ /** Returns the number of key-value mappings in this map. */
+ override def size(): Int = mapData.numElements()
+
+ /**
+ * Returns an array view of the keys contained in this map.
+ *
+ * <p>A key-value pair has the same index in the key array and value array.
+ */
+ override def keyArray(): FlussInternalArray = {
+ new SparkAsFlussArray(mapData.keyArray(), keyType)
+ }
+
+ /**
+ * Returns an array view of the values contained in this map.
+ *
+ * <p>A key-value pair has the same index in the key array and value array.
+ */
+ override def valueArray(): FlussInternalArray = {
+ new SparkAsFlussArray(mapData.valueArray(), valueType)
+ }
+}
+
+object SparkAsFlussMap {
+ def apply(mapData: SparkMapData, mapType: SparkMapType): SparkAsFlussMap =
+ new SparkAsFlussMap(mapData, mapType.keyType, mapType.valueType)
+}
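
The reverse direction, as a hedged sketch (not part of the patch; values are hypothetical): wrap Spark MapData as a Fluss InternalMap via the companion apply:

    import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData}
    import org.apache.spark.sql.types.{IntegerType, MapType, StringType}
    import org.apache.spark.unsafe.types.UTF8String

    val sparkMap = new ArrayBasedMapData(
      ArrayData.toArrayData(Array[Any](UTF8String.fromString("k1"))),
      ArrayData.toArrayData(Array(1)))
    val flussMap = SparkAsFlussMap(sparkMap, MapType(StringType, IntegerType))
    // Parallel key/value arrays mirror the underlying Spark layout.
    flussMap.size()                           // 1
    flussMap.keyArray().getString(0).toString // "k1"
    flussMap.valueArray().getInt(0)           // 1
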
diff --git a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/SparkAsFlussRow.scala b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/SparkAsFlussRow.scala
index 3a5c9613c2..368a7c973d 100644
--- a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/SparkAsFlussRow.scala
+++ b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/row/SparkAsFlussRow.scala
@@ -20,7 +20,7 @@ package org.apache.fluss.spark.row
import org.apache.fluss.row.{BinaryString, Decimal, InternalMap, InternalRow => FlussInternalRow, TimestampLtz, TimestampNtz}
import org.apache.spark.sql.catalyst.{InternalRow => SparkInternalRow}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{MapType => SparkMapType, StructType}
/** Wraps a Spark [[SparkInternalRow]] as a Fluss [[FlussInternalRow]]. */
class SparkAsFlussRow(schema: StructType) extends FlussInternalRow with Serializable {
@@ -127,6 +127,8 @@ class SparkAsFlussRow(schema: StructType) extends FlussInternalRow with Serializ
/** Returns the map value at the given position. */
override def getMap(pos: Int): InternalMap = {
- throw new UnsupportedOperationException()
+ val sparkMapData = row.getMap(pos)
+ val mapType = schema.fields(pos).dataType.asInstanceOf[SparkMapType]
+ SparkAsFlussMap(sparkMapData, mapType)
}
}
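
For context, an end-to-end sketch of the row path (not part of the patch; schema and values are hypothetical): getMap on a wrapped Spark row now returns a SparkAsFlussMap view instead of throwing:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
    import org.apache.spark.sql.types.{IntegerType, MapType, StringType, StructField, StructType}
    import org.apache.spark.unsafe.types.UTF8String

    val schema = StructType(Seq(StructField("m", MapType(StringType, IntegerType))))
    val sparkRow = InternalRow(
      ArrayBasedMapData(Array[Any](UTF8String.fromString("k")), Array(1)))
    val flussRow = new SparkAsFlussRow(schema).replace(sparkRow)
    // The map field is read lazily through the wrapper.
    flussRow.getMap(0).size() // 1
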
diff --git a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/SparkLogTableReadTest.scala b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/SparkLogTableReadTest.scala
index 5f9613cbc3..2342d658c2 100644
--- a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/SparkLogTableReadTest.scala
+++ b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/SparkLogTableReadTest.scala
@@ -213,18 +213,18 @@ class SparkLogTableReadTest extends FlussSparkTestBase {
test("Spark Read: nested data types table") {
withTable("t") {
- // TODO: support map type
sql(s"""
|CREATE TABLE $DEFAULT_DATABASE.t (
|id INT,
|arr ARRAY<INT>,
+ |map MAP<STRING, INT>,
|struct_col STRUCT<f1: INT, f2: STRING>
|)""".stripMargin)
sql(s"""
|INSERT INTO $DEFAULT_DATABASE.t VALUES
- |(1, ARRAY(1, 2, 3), STRUCT(100, 'nested_value')),
- |(2, ARRAY(7, 8, 9), STRUCT(200, 'nested_value2'))
+ |(1, ARRAY(1, 2, 3), MAP("k1", 111, "k2", 222), STRUCT(100, 'nested_value')),
+ |(2, ARRAY(7, 8, 9), MAP("k1", 333, "k2", 444), STRUCT(200, 'nested_value2'))
|""".stripMargin)
checkAnswer(
@@ -232,10 +232,12 @@ class SparkLogTableReadTest extends FlussSparkTestBase {
Row(
1,
Seq(1, 2, 3),
+ Map("k1" -> 111, "k2" -> 222),
Row(100, "nested_value")
) :: Row(
2,
Seq(7, 8, 9),
+ Map("k1" -> 333, "k2" -> 444),
Row(200, "nested_value2")
) :: Nil
)
diff --git a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/SparkWriteTest.scala b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/SparkWriteTest.scala
index 179c05e801..a2382c1616 100644
--- a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/SparkWriteTest.scala
+++ b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/SparkWriteTest.scala
@@ -47,7 +47,9 @@ class SparkWriteTest extends FlussSparkTestBase {
| 1234567.89, 12345678900987654321.12,
| "test",
| TO_TIMESTAMP('2025-12-31 10:00:00', 'yyyy-MM-dd kk:mm:ss'),
- | array(11.11F, 22.22F), struct(123L, "apache fluss")
+ | array(11.11F, 22.22F),
+ | map("k1", 111, "k2", 222),
+ | struct(123L, "apache fluss")
|)
|""".stripMargin)
@@ -56,7 +58,7 @@ class SparkWriteTest extends FlussSparkTestBase {
assertThat(rows.length).isEqualTo(1)
val row = rows.head
- assertThat(row.getFieldCount).isEqualTo(13)
+ assertThat(row.getFieldCount).isEqualTo(14)
assertThat(row.getBoolean(0)).isEqualTo(true)
assertThat(row.getByte(1)).isEqualTo(1.toByte)
assertThat(row.getShort(2)).isEqualTo(10.toShort)
@@ -71,7 +73,12 @@ class SparkWriteTest extends FlussSparkTestBase {
assertThat(row.getTimestampLtz(10, 6).toInstant)
.isEqualTo(Timestamp.valueOf("2025-12-31 10:00:00.0").toInstant)
assertThat(row.getArray(11).toFloatArray).containsExactly(Array(11.11f, 22.22f): _*)
- val nestedRow = row.getRow(12, 2)
+ val mapData = row.getMap(12)
+ assertThat(mapData.size()).isEqualTo(2)
+ assertThat(mapData.keyArray().getString(0).toString).isEqualTo("k1")
+ assertThat(mapData.keyArray().getString(1).toString).isEqualTo("k2")
+ assertThat(mapData.valueArray().toIntArray).containsExactly(Array(111, 222): _*)
+ val nestedRow = row.getRow(13, 2)
assertThat(nestedRow.getFieldCount).isEqualTo(2)
assertThat(nestedRow.getLong(0)).isEqualTo(123L)
assertThat(nestedRow.getString(1).toString).isEqualTo("apache fluss")
diff --git a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/DataConverterTest.scala b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/DataConverterTest.scala
index 6866ba7b60..d7cfc9d2f9 100644
--- a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/DataConverterTest.scala
+++ b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/DataConverterTest.scala
@@ -17,14 +17,16 @@
package org.apache.fluss.spark.row
-import org.apache.fluss.row.{BinaryString, Decimal => FlussDecimal, GenericArray, GenericRow, TimestampLtz, TimestampNtz}
-import org.apache.fluss.types.{ArrayType, CharType, DataTypes, DecimalType, LocalZonedTimestampType, RowType, TimestampType}
+import org.apache.fluss.row.{BinaryString, Decimal => FlussDecimal, GenericArray, GenericMap, GenericRow, TimestampLtz, TimestampNtz}
+import org.apache.fluss.types.{ArrayType, CharType, DataTypes, DecimalType, LocalZonedTimestampType, MapType, RowType, TimestampType}
import org.apache.spark.sql.types.{Decimal => SparkDecimal}
import org.apache.spark.unsafe.types.UTF8String
import org.assertj.core.api.Assertions.assertThat
import org.scalatest.funsuite.AnyFunSuite
+import scala.collection.JavaConverters._
+
class DataConverterTest extends AnyFunSuite {
test("toSparkObject: null value") {
@@ -274,6 +276,28 @@ class DataConverterTest extends AnyFunSuite {
assertThat(result.asInstanceOf[Long]).isEqualTo(2000000000L) // microseconds
}
+ test("toSparkMap: Map type") {
+ val flussMap = new GenericMap(
+ Map(
+ BinaryString.fromString("a") -> Integer.valueOf(1),
+ BinaryString.fromString("b") -> Integer.valueOf(2)).asJava)
+ val mapType = new MapType(DataTypes.STRING, DataTypes.INT)
+ val sparkMap = DataConverter.toSparkMap(flussMap, mapType)
+ assertThat(sparkMap.numElements()).isEqualTo(2)
+
+ val keyArray = sparkMap.keyArray()
+ val valueArray = sparkMap.valueArray()
+ assertThat(keyArray.numElements()).isEqualTo(2)
+ assertThat(valueArray.numElements()).isEqualTo(2)
+
+ val actualMap =
+ (0 until sparkMap.numElements())
+ .map(i => keyArray.getUTF8String(i).toString -> valueArray.getInt(i))
+ .toMap
+
+ assertThat(actualMap).isEqualTo(Map("a" -> 1, "b" -> 2))
+ }
+
test("toSparkObject: ROW type") {
val rowType = RowType
.builder()
@@ -288,10 +312,4 @@ class DataConverterTest extends AnyFunSuite {
assertThat(result).isNotNull()
assertThat(result.asInstanceOf[FlussAsSparkRow].getInt(0)).isEqualTo(42)
}
-
- test("toSparkMap: unsupported") {
- assertThrows[UnsupportedOperationException] {
- DataConverter.toSparkMap(null, null)
- }
- }
}
diff --git a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/FlussAsSparkArrayTest.scala b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/FlussAsSparkArrayTest.scala
index 029104e5bf..27a4bd529e 100644
--- a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/FlussAsSparkArrayTest.scala
+++ b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/FlussAsSparkArrayTest.scala
@@ -308,14 +308,17 @@ class FlussAsSparkArrayTest extends AnyFunSuite {
assertThat(sparkInnerArray2.getInt(2)).isEqualTo(6)
}
- test("getMap: unsupported operation") {
+ test("getMap: read map array") {
val mapType = DataTypes.MAP(DataTypes.INT, DataTypes.STRING)
- val flussArray = GenericArray.of(new GenericMap(Map(1 -> "map").asJava))
+ val innerMap =
+ new GenericMap(Map(Integer.valueOf(1) -> BinaryString.fromString("value1")).asJava)
+ val flussArray = new GenericArray(Array[Object](innerMap))
val sparkArray = new FlussAsSparkArray(mapType).replace(flussArray)
- assertThrows[UnsupportedOperationException] {
- sparkArray.getMap(0)
- }
+ val sparkMap = sparkArray.getMap(0)
+ assertThat(sparkMap.numElements()).isEqualTo(1)
+ assertThat(sparkMap.keyArray().getInt(0)).isEqualTo(1)
+ assertThat(sparkMap.valueArray().getUTF8String(0).toString).isEqualTo("value1")
}
test("getInterval: unsupported operation") {
diff --git a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/FlussAsSparkMapTest.scala b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/FlussAsSparkMapTest.scala
new file mode 100644
index 0000000000..91d492dbda
--- /dev/null
+++ b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/FlussAsSparkMapTest.scala
@@ -0,0 +1,217 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.fluss.spark.row
+
+import org.apache.fluss.row.{BinaryString, GenericArray, GenericMap, GenericRow}
+import org.apache.fluss.spark.FlussSparkTestBase
+import org.apache.fluss.types.{ArrayType, BigIntType, CharType, IntType, MapType, RowType, StringType}
+
+import org.assertj.core.api.Assertions.assertThat
+
+import scala.collection.JavaConverters._
+
+class FlussAsSparkMapTest extends FlussSparkTestBase {
+
+ test("numElements: empty map") {
+ val flussMap = new GenericMap(Map.empty[Object, Object].asJava)
+ val sparkMap = new FlussAsSparkMap(new MapType(new CharType(10), new IntType()))
+ .replace(flussMap)
+
+ assertThat(sparkMap.numElements()).isEqualTo(0)
+ }
+
+ test("numElements: non-empty map") {
+ val flussMap = createSimpleMap()
+ val sparkMap = new FlussAsSparkMap(new MapType(new CharType(10), new IntType()))
+ .replace(flussMap)
+
+ assertThat(sparkMap.numElements()).isEqualTo(3)
+ }
+
+ test("keyArray: empty map") {
+ val flussMap = new GenericMap(Map.empty[Object, Object].asJava)
+ val sparkMap = new FlussAsSparkMap(new MapType(new CharType(10), new IntType()))
+ .replace(flussMap)
+
+ val sparkKeyArray = sparkMap.keyArray()
+ assertThat(sparkKeyArray.numElements()).isEqualTo(0)
+ }
+
+ test("keyArray: non-empty map") {
+ val flussMap = createSimpleMap()
+ val sparkMap = new FlussAsSparkMap(new MapType(new CharType(10), new IntType()))
+ .replace(flussMap)
+
+ val sparkKeyArray = sparkMap.keyArray()
+ assertThat(sparkKeyArray.numElements()).isEqualTo(3)
+ // Key ordering is implementation-defined for GenericMap, so compare keys as a set
+ val key0 = sparkKeyArray.getUTF8String(0).toString
+ val key1 = sparkKeyArray.getUTF8String(1).toString
+ val key2 = sparkKeyArray.getUTF8String(2).toString
+ assertThat(Set(key0, key1, key2)).isEqualTo(Set("key1", "key2", "key3"))
+ }
+
+ test("valueArray: empty map") {
+ val flussMap = new GenericMap(Map.empty[Object, Object].asJava)
+ val sparkMap = new FlussAsSparkMap(new MapType(new CharType(10), new IntType()))
+ .replace(flussMap)
+
+ val sparkValueArray = sparkMap.valueArray()
+ assertThat(sparkValueArray.numElements()).isEqualTo(0)
+ }
+
+ test("valueArray: non-empty map") {
+ val flussMap = createSimpleMap()
+ val sparkMap = new FlussAsSparkMap(new MapType(new CharType(10), new IntType()))
+ .replace(flussMap)
+
+ val sparkValueArray = sparkMap.valueArray()
+ assertThat(sparkValueArray.numElements()).isEqualTo(3)
+ assertThat(sparkValueArray.getInt(0)).isEqualTo(100)
+ assertThat(sparkValueArray.getInt(1)).isEqualTo(200)
+ assertThat(sparkValueArray.getInt(2)).isEqualTo(300)
+ }
+
+ test("copy: creates deep copy") {
+ val flussMap = createSimpleMap()
+ val originalSparkMap = new FlussAsSparkMap(new MapType(new CharType(10), new IntType()))
+ .replace(flussMap)
+ val copiedSparkMap = originalSparkMap.copy()
+
+ assertThat(copiedSparkMap.numElements()).isEqualTo(3)
+ }
+
+ test("integration: map with nested array") {
+ val flussMap = createMapWithNestedArrays()
+ val sparkMap = new FlussAsSparkMap(new MapType(new CharType(10), new ArrayType(new IntType())))
+ .replace(flussMap)
+
+ assertThat(sparkMap.numElements()).isEqualTo(2)
+
+ val sparkKeyArray = sparkMap.keyArray()
+ val key0 = sparkKeyArray.getUTF8String(0).toString
+ val key1 = sparkKeyArray.getUTF8String(1).toString
+ assertThat(Set(key0, key1)).isEqualTo(Set("arr1", "arr2"))
+
+ val sparkValueArray = sparkMap.valueArray()
+ // Check that we have 2 arrays
+ assertThat(sparkValueArray.numElements()).isEqualTo(2)
+ }
+
+ test("integration: map with nested row") {
+ val flussMap = createMapWithNestedRows()
+ val sparkMap = new FlussAsSparkMap(new MapType(new CharType(10), createSimpleRowType()))
+ .replace(flussMap)
+
+ assertThat(sparkMap.numElements()).isEqualTo(2)
+
+ val sparkKeyArray = sparkMap.keyArray()
+ val key0 = sparkKeyArray.getUTF8String(0).toString
+ val key1 = sparkKeyArray.getUTF8String(1).toString
+ assertThat(Set(key0, key1)).isEqualTo(Set("row1", "row2"))
+ }
+
+ test("integration: map with nested map") {
+ val flussMap = createMapWithNestedMaps()
+ val sparkMap = new FlussAsSparkMap(new MapType(new CharType(10), createSimpleMapType()))
+ .replace(flussMap)
+
+ assertThat(sparkMap.numElements()).isEqualTo(2)
+
+ val sparkKeyArray = sparkMap.keyArray()
+ val key0 = sparkKeyArray.getUTF8String(0).toString
+ val key1 = sparkKeyArray.getUTF8String(1).toString
+ assertThat(Set(key0, key1)).isEqualTo(Set("map1", "map2"))
+
+ val sparkValueArray = sparkMap.valueArray()
+ assertThat(sparkValueArray.numElements()).isEqualTo(2)
+ }
+
+ private def createSimpleMap(): GenericMap = {
+ new GenericMap(
+ Map(
+ BinaryString.fromString("key1") -> Integer.valueOf(100),
+ BinaryString.fromString("key2") -> Integer.valueOf(200),
+ BinaryString.fromString("key3") -> Integer.valueOf(300)
+ ).asJava
+ )
+ }
+
+ private def createMapWithNestedArrays(): GenericMap = {
+ new GenericMap(
+ Map(
+ BinaryString.fromString("arr1") -> new GenericArray(
+ Array[Object](Integer.valueOf(1), Integer.valueOf(2), Integer.valueOf(3))),
+ BinaryString.fromString("arr2") -> new GenericArray(
+ Array[Object](Integer.valueOf(4), Integer.valueOf(5), Integer.valueOf(6)))
+ ).asJava
+ )
+ }
+
+ private def createMapWithNestedRows(): GenericMap = {
+ val row1 = new GenericRow(2)
+ row1.setField(0, java.lang.Long.valueOf(100L))
+ row1.setField(1, BinaryString.fromString("value1"))
+
+ val row2 = new GenericRow(2)
+ row2.setField(0, java.lang.Long.valueOf(200L))
+ row2.setField(1, BinaryString.fromString("value2"))
+
+ new GenericMap(
+ Map(
+ BinaryString.fromString("row1") -> row1,
+ BinaryString.fromString("row2") -> row2
+ ).asJava
+ )
+ }
+
+ private def createMapWithNestedMaps(): GenericMap = {
+ val innerMap1 = new GenericMap(
+ Map(
+ BinaryString.fromString("inner1") -> Integer.valueOf(10),
+ BinaryString.fromString("inner2") -> Integer.valueOf(20)
+ ).asJava
+ )
+
+ val innerMap2 = new GenericMap(
+ Map(
+ BinaryString.fromString("inner3") -> Integer.valueOf(30),
+ BinaryString.fromString("inner4") -> Integer.valueOf(40)
+ ).asJava
+ )
+
+ new GenericMap(
+ Map(
+ BinaryString.fromString("map1") -> innerMap1,
+ BinaryString.fromString("map2") -> innerMap2
+ ).asJava
+ )
+ }
+
+ private def createSimpleRowType(): RowType = {
+ val fields = java.util.Arrays.asList(
+ new org.apache.fluss.types.DataField("field1", new BigIntType()),
+ new org.apache.fluss.types.DataField("field2", new StringType())
+ )
+ new RowType(fields)
+ }
+
+ private def createSimpleMapType(): MapType = {
+ new MapType(new CharType(10), new IntType())
+ }
+}
diff --git a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/FlussAsSparkRowTest.scala b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/FlussAsSparkRowTest.scala
index 03ccad4ef6..697fbd2f4d 100644
--- a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/FlussAsSparkRowTest.scala
+++ b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/FlussAsSparkRowTest.scala
@@ -17,12 +17,14 @@
package org.apache.fluss.spark.row
-import org.apache.fluss.row.{BinaryString, Decimal => FlussDecimal, GenericArray, GenericRow, TimestampLtz, TimestampNtz}
+import org.apache.fluss.row.{BinaryString, Decimal => FlussDecimal, GenericArray, GenericMap, GenericRow, TimestampLtz, TimestampNtz}
import org.apache.fluss.types.{ArrayType, BinaryType, DataTypes, LocalZonedTimestampType, RowType, TimestampType}
import org.assertj.core.api.Assertions.assertThat
import org.scalatest.funsuite.AnyFunSuite
+import scala.collection.JavaConverters._
+
class FlussAsSparkRowTest extends AnyFunSuite {
test("basic row operations: numFields and fieldCount") {
@@ -339,20 +341,24 @@ class FlussAsSparkRowTest extends AnyFunSuite {
assertThat(sparkArray.getInt(4)).isEqualTo(5)
}
- test("getMap: unsupported operation") {
+ test("getMap: read map field") {
+ val mapType = DataTypes.MAP(DataTypes.INT, DataTypes.STRING)
val rowType = RowType
.builder()
- .field("dummy", DataTypes.INT)
+ .field("map_col", mapType)
.build()
+ val flussMap =
+ new GenericMap(Map(Integer.valueOf(1) -> BinaryString.fromString("value1")).asJava)
val flussRow = new GenericRow(1)
- flussRow.setField(0, Integer.valueOf(1))
+ flussRow.setField(0, flussMap)
val sparkRow = new FlussAsSparkRow(rowType).replace(flussRow)
- assertThrows[UnsupportedOperationException] {
- sparkRow.getMap(0)
- }
+ val sparkMap = sparkRow.getMap(0)
+ assertThat(sparkMap.numElements()).isEqualTo(1)
+ assertThat(sparkMap.keyArray().getInt(0)).isEqualTo(1)
+ assertThat(sparkMap.valueArray().getUTF8String(0).toString).isEqualTo("value1")
}
test("getInterval: unsupported operation") {
diff --git a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/SparkAsFlussArrayTest.scala b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/SparkAsFlussArrayTest.scala
index fb4ee6308a..74f07a58d2 100644
--- a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/SparkAsFlussArrayTest.scala
+++ b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/SparkAsFlussArrayTest.scala
@@ -20,7 +20,7 @@ package org.apache.fluss.spark.row
import org.apache.fluss.spark.FlussSparkTestBase
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
import org.apache.spark.sql.types.{BooleanType, ByteType, Decimal, DecimalType, DoubleType, FloatType, IntegerType, LongType, ShortType, StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String
import org.assertj.core.api.Assertions.assertThat
@@ -129,4 +129,42 @@ class SparkAsFlussArrayTest extends FlussSparkTestBase {
}
}
}
+
+ test("Fluss SparkAsFlussArray: Map Type") {
+ val mapType = org.apache.spark.sql.types.MapType(StringType, IntegerType)
+ val data = Array(
+ ArrayBasedMapData.apply(
+ Array(UTF8String.fromString("key1"), UTF8String.fromString("key2"))
+ .asInstanceOf[Array[Any]],
+ Array(100, 200)),
+ ArrayBasedMapData.apply(
+ Array(UTF8String.fromString("key3"), UTF8String.fromString("key4"))
+ .asInstanceOf[Array[Any]],
+ Array(300, 400)),
+ null
+ )
+ val sparkArrayData = new GenericArrayData(data)
+ val flussArray = new SparkAsFlussArray(sparkArrayData, mapType)
+
+ assertThat(flussArray.size()).isEqualTo(3)
+ assertThat(flussArray.isNullAt(0)).isFalse()
+ assertThat(flussArray.isNullAt(1)).isFalse()
+ assertThat(flussArray.isNullAt(2)).isTrue()
+
+ // Check first map
+ val map1 = flussArray.getMap(0)
+ assertThat(map1.size()).isEqualTo(2)
+ assertThat(map1.keyArray().getString(0).toString).isEqualTo("key1")
+ assertThat(map1.keyArray().getString(1).toString).isEqualTo("key2")
+ assertThat(map1.valueArray().getInt(0)).isEqualTo(100)
+ assertThat(map1.valueArray().getInt(1)).isEqualTo(200)
+
+ // Check second map
+ val map2 = flussArray.getMap(1)
+ assertThat(map2.size()).isEqualTo(2)
+ assertThat(map2.keyArray().getString(0).toString).isEqualTo("key3")
+ assertThat(map2.keyArray().getString(1).toString).isEqualTo("key4")
+ assertThat(map2.valueArray().getInt(0)).isEqualTo(300)
+ assertThat(map2.valueArray().getInt(1)).isEqualTo(400)
+ }
}
diff --git a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/SparkAsFlussMapTest.scala b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/SparkAsFlussMapTest.scala
new file mode 100644
index 0000000000..383ccf0e54
--- /dev/null
+++ b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/SparkAsFlussMapTest.scala
@@ -0,0 +1,429 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.fluss.spark.row
+
+import org.apache.fluss.spark.FlussSparkTestBase
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData}
+import org.apache.spark.sql.types.{IntegerType, MapType, StringType, StructType}
+import org.apache.spark.unsafe.types.UTF8String
+import org.assertj.core.api.Assertions.assertThat
+
+class SparkAsFlussMapTest extends FlussSparkTestBase {
+
+ // Converts the given strings to an Array[Any] of UTF8String values
+ private def toUTF8Strings(strings: String*): Array[Any] = {
+ strings.map(UTF8String.fromString).toArray
+ }
+
+ test("size: empty map") {
+ val sparkMap = new ArrayBasedMapData(
+ ArrayData.toArrayData(Array.empty[Any]),
+ ArrayData.toArrayData(Array.empty[Any]))
+ val mapType = MapType(StringType, IntegerType)
+ val flussMap = SparkAsFlussMap(sparkMap, mapType)
+
+ assertThat(flussMap.size()).isEqualTo(0)
+ }
+
+ test("size: non-empty map") {
+ val keys = toUTF8Strings("key1", "key2", "key3")
+ val values = Array(100, 200, 300)
+ val sparkMap = new ArrayBasedMapData(ArrayData.toArrayData(keys), ArrayData.toArrayData(values))
+ val mapType = MapType(StringType, IntegerType)
+ val flussMap = SparkAsFlussMap(sparkMap, mapType)
+
+ assertThat(flussMap.size()).isEqualTo(3)
+ }
+
+ test("keyArray: empty map") {
+ val sparkMap = new ArrayBasedMapData(
+ ArrayData.toArrayData(Array.empty[Any]),
+ ArrayData.toArrayData(Array.empty[Any]))
+ val mapType = MapType(StringType, IntegerType)
+ val flussMap = SparkAsFlussMap(sparkMap, mapType)
+
+ val keyArray = flussMap.keyArray()
+ assertThat(keyArray.size()).isEqualTo(0)
+ }
+
+ test("keyArray: non-empty map") {
+ val keys = toUTF8Strings("key1", "key2", "key3")
+ val values = Array(100, 200, 300)
+ val sparkMap = new ArrayBasedMapData(ArrayData.toArrayData(keys), ArrayData.toArrayData(values))
+ val mapType = MapType(StringType, IntegerType)
+ val flussMap = SparkAsFlussMap(sparkMap, mapType)
+
+ val keyArray = flussMap.keyArray()
+ assertThat(keyArray.size()).isEqualTo(3)
+ assertThat(keyArray.getString(0).toString).isEqualTo("key1")
+ assertThat(keyArray.getString(1).toString).isEqualTo("key2")
+ assertThat(keyArray.getString(2).toString).isEqualTo("key3")
+ }
+
+ test("valueArray: empty map") {
+ val sparkMap = new ArrayBasedMapData(
+ ArrayData.toArrayData(Array.empty[Any]),
+ ArrayData.toArrayData(Array.empty[Any]))
+ val mapType = MapType(StringType, IntegerType)
+ val flussMap = SparkAsFlussMap(sparkMap, mapType)
+
+ val valueArray = flussMap.valueArray()
+ assertThat(valueArray.size()).isEqualTo(0)
+ }
+
+ test("valueArray: non-empty map") {
+ val keys = toUTF8Strings("key1", "key2", "key3")
+ val values = Array(100, 200, 300)
+ val sparkMap = new ArrayBasedMapData(ArrayData.toArrayData(keys), ArrayData.toArrayData(values))
+ val mapType = MapType(StringType, IntegerType)
+ val flussMap = SparkAsFlussMap(sparkMap, mapType)
+
+ val valueArray = flussMap.valueArray()
+ assertThat(valueArray.size()).isEqualTo(3)
+ assertThat(valueArray.getInt(0)).isEqualTo(100)
+ assertThat(valueArray.getInt(1)).isEqualTo(200)
+ assertThat(valueArray.getInt(2)).isEqualTo(300)
+ }
+
+ test("integration: map with nested array") {
+ val keys = toUTF8Strings("array1", "array2")
+ val values = Array(
+ ArrayData.toArrayData(Array(1, 2, 3)),
+ ArrayData.toArrayData(Array(4, 5, 6))
+ )
+ val sparkMap = new ArrayBasedMapData(ArrayData.toArrayData(keys), ArrayData.toArrayData(values))
+ val mapType = MapType(StringType, org.apache.spark.sql.types.ArrayType(IntegerType))
+ val flussMap = SparkAsFlussMap(sparkMap, mapType)
+
+ assertThat(flussMap.size()).isEqualTo(2)
+
+ val keyArray = flussMap.keyArray()
+ assertThat(keyArray.getString(0).toString).isEqualTo("array1")
+ assertThat(keyArray.getString(1).toString).isEqualTo("array2")
+
+ val valueArray = flussMap.valueArray()
+ val array1 = valueArray.getArray(0)
+ assertThat(array1.size()).isEqualTo(3)
+ assertThat(array1.getInt(0)).isEqualTo(1)
+ assertThat(array1.getInt(1)).isEqualTo(2)
+ assertThat(array1.getInt(2)).isEqualTo(3)
+
+ val array2 = valueArray.getArray(1)
+ assertThat(array2.size()).isEqualTo(3)
+ assertThat(array2.getInt(0)).isEqualTo(4)
+ assertThat(array2.getInt(1)).isEqualTo(5)
+ assertThat(array2.getInt(2)).isEqualTo(6)
+ }
+
+ test("integration: map with nested row") {
+ val keys = toUTF8Strings("row1", "row2")
+ val values = Array(
+ InternalRow.apply(UTF8String.fromString("value1"), 100),
+ InternalRow.apply(UTF8String.fromString("value2"), 200)
+ )
+ val sparkMap = new ArrayBasedMapData(ArrayData.toArrayData(keys), ArrayData.toArrayData(values))
+ val structType = new StructType().add("name", StringType).add("value", IntegerType)
+ val mapType = MapType(StringType, structType)
+ val flussMap = SparkAsFlussMap(sparkMap, mapType)
+
+ assertThat(flussMap.size()).isEqualTo(2)
+
+ val keyArray = flussMap.keyArray()
+ assertThat(keyArray.getString(0).toString).isEqualTo("row1")
+ assertThat(keyArray.getString(1).toString).isEqualTo("row2")
+
+ val valueArray = flussMap.valueArray()
+ val row1 = valueArray.getRow(0, 2)
+ assertThat(row1.getString(0).toString).isEqualTo("value1")
+ assertThat(row1.getInt(1)).isEqualTo(100)
+
+ val row2 = valueArray.getRow(1, 2)
+ assertThat(row2.getString(0).toString).isEqualTo("value2")
+ assertThat(row2.getInt(1)).isEqualTo(200)
+ }
+
+ test("integration: map with nested map") {
+ val keys = toUTF8Strings("map1", "map2")
+ val values = Array(
+ new ArrayBasedMapData(
+ ArrayData.toArrayData(toUTF8Strings("inner1", "inner2")),
+ ArrayData.toArrayData(Array(10, 20))),
+ new ArrayBasedMapData(
+ ArrayData.toArrayData(toUTF8Strings("inner3", "inner4")),
+ ArrayData.toArrayData(Array(30, 40)))
+ )
+ val sparkMap = new ArrayBasedMapData(ArrayData.toArrayData(keys), ArrayData.toArrayData(values))
+ val innerMapType = MapType(StringType, IntegerType)
+ val mapType = MapType(StringType, innerMapType)
+ val flussMap = SparkAsFlussMap(sparkMap, mapType)
+
+ assertThat(flussMap.size()).isEqualTo(2)
+
+ val keyArray = flussMap.keyArray()
+ assertThat(keyArray.getString(0).toString).isEqualTo("map1")
+ assertThat(keyArray.getString(1).toString).isEqualTo("map2")
+
+ val valueArray = flussMap.valueArray()
+ val innerMap1 = valueArray.getMap(0)
+ assertThat(innerMap1.size()).isEqualTo(2)
+ val innerMap1Keys = innerMap1.keyArray()
+ assertThat(innerMap1Keys.getString(0).toString).isEqualTo("inner1")
+ assertThat(innerMap1Keys.getString(1).toString).isEqualTo("inner2")
+ val innerMap1Values = innerMap1.valueArray()
+ assertThat(innerMap1Values.getInt(0)).isEqualTo(10)
+ assertThat(innerMap1Values.getInt(1)).isEqualTo(20)
+
+ val innerMap2 = valueArray.getMap(1)
+ assertThat(innerMap2.size()).isEqualTo(2)
+ val innerMap2Keys = innerMap2.keyArray()
+ assertThat(innerMap2Keys.getString(0).toString).isEqualTo("inner3")
+ assertThat(innerMap2Keys.getString(1).toString).isEqualTo("inner4")
+ val innerMap2Values = innerMap2.valueArray()
+ assertThat(innerMap2Values.getInt(0)).isEqualTo(30)
+ assertThat(innerMap2Values.getInt(1)).isEqualTo(40)
+ }
+
+ test("basic accessors return expected keys and values") {
+ val keys = toUTF8Strings("key1", "key2")
+ val values = Array(100, 200)
+ val sparkMap = new ArrayBasedMapData(ArrayData.toArrayData(keys), ArrayData.toArrayData(values))
+ val mapType = MapType(StringType, IntegerType)
+ val flussMap = SparkAsFlussMap(sparkMap, mapType)
+
+ // Verify the wrapped map exposes the expected size, keys, and values.
+ assertThat(flussMap.size()).isEqualTo(2)
+ assertThat(flussMap.keyArray().getString(0).toString).isEqualTo("key1")
+ assertThat(flussMap.valueArray().getInt(1)).isEqualTo(200)
+ }
+
+ test("map with integer values") {
+ val keys = toUTF8Strings("int_key1", "int_key2")
+ val values = Array(100, 200)
+ val sparkMap = new ArrayBasedMapData(ArrayData.toArrayData(keys), ArrayData.toArrayData(values))
+ val mapType = MapType(StringType, IntegerType)
+ val flussMap = SparkAsFlussMap(sparkMap, mapType)
+
+ assertThat(flussMap.size()).isEqualTo(2)
+ val valueArray = flussMap.valueArray()
+ assertThat(valueArray.getInt(0)).isEqualTo(100)
+ assertThat(valueArray.getInt(1)).isEqualTo(200)
+ }
+
+ test("map with float values") {
+ val keys = toUTF8Strings("float_key1", "float_key2")
+ val values = Array(12.34f, 56.78f)
+ val sparkMap = new ArrayBasedMapData(ArrayData.toArrayData(keys), ArrayData.toArrayData(values))
+ val mapType = MapType(StringType, org.apache.spark.sql.types.FloatType)
+ val flussMap = SparkAsFlussMap(sparkMap, mapType)
+
+ assertThat(flussMap.size()).isEqualTo(2)
+ val valueArray = flussMap.valueArray()
+ assertThat(valueArray.getFloat(0)).isEqualTo(12.34f)
+ assertThat(valueArray.getFloat(1)).isEqualTo(56.78f)
+ }
+
+ test("map with double values") {
+ val keys = toUTF8Strings("double_key1", "double_key2")
+ val values = Array(56.78, 90.12)
+ val sparkMap = new ArrayBasedMapData(ArrayData.toArrayData(keys), ArrayData.toArrayData(values))
+ val mapType = MapType(StringType, org.apache.spark.sql.types.DoubleType)
+ val flussMap = SparkAsFlussMap(sparkMap, mapType)
+
+ assertThat(flussMap.size()).isEqualTo(2)
+ val valueArray = flussMap.valueArray()
+ assertThat(valueArray.getDouble(0)).isEqualTo(56.78)
+ assertThat(valueArray.getDouble(1)).isEqualTo(90.12)
+ }
+
+ test("map with long values") {
+ val keys = toUTF8Strings("long_key1", "long_key2")
+ val values = Array(1000L, 2000L)
+ val sparkMap = new ArrayBasedMapData(ArrayData.toArrayData(keys), ArrayData.toArrayData(values))
+ val mapType = MapType(StringType, org.apache.spark.sql.types.LongType)
+ val flussMap = SparkAsFlussMap(sparkMap, mapType)
+
+ assertThat(flussMap.size()).isEqualTo(2)
+ val valueArray = flussMap.valueArray()
+ assertThat(valueArray.getLong(0)).isEqualTo(1000L)
+ assertThat(valueArray.getLong(1)).isEqualTo(2000L)
+ }
+
+ test("map with boolean values") {
+ val keys = toUTF8Strings("bool_key1", "bool_key2")
+ val values = Array(true, false)
+ val sparkMap = new ArrayBasedMapData(ArrayData.toArrayData(keys), ArrayData.toArrayData(values))
+ val mapType = MapType(StringType, org.apache.spark.sql.types.BooleanType)
+ val flussMap = SparkAsFlussMap(sparkMap, mapType)
+
+ assertThat(flussMap.size()).isEqualTo(2)
+ val valueArray = flussMap.valueArray()
+ assertThat(valueArray.getBoolean(0)).isTrue()
+ assertThat(valueArray.getBoolean(1)).isFalse()
+ }
+
+ test("map with byte values") {
+ val keys = toUTF8Strings("byte_key1", "byte_key2")
+ val values = Array(127.toByte, 64.toByte)
+ val sparkMap = new ArrayBasedMapData(ArrayData.toArrayData(keys), ArrayData.toArrayData(values))
+ val mapType = MapType(StringType, org.apache.spark.sql.types.ByteType)
+ val flussMap = SparkAsFlussMap(sparkMap, mapType)
+
+ assertThat(flussMap.size()).isEqualTo(2)
+ val valueArray = flussMap.valueArray()
+ assertThat(valueArray.getByte(0)).isEqualTo(127.toByte)
+ assertThat(valueArray.getByte(1)).isEqualTo(64.toByte)
+ }
+
+ test("map with short values") {
+ val keys = toUTF8Strings("short_key1", "short_key2")
+ val values = Array(1000.toShort, 2000.toShort)
+ val sparkMap = new ArrayBasedMapData(ArrayData.toArrayData(keys), ArrayData.toArrayData(values))
+ val mapType = MapType(StringType, org.apache.spark.sql.types.ShortType)
+ val flussMap = SparkAsFlussMap(sparkMap, mapType)
+
+ assertThat(flussMap.size()).isEqualTo(2)
+ val valueArray = flussMap.valueArray()
+ assertThat(valueArray.getShort(0)).isEqualTo(1000.toShort)
+ assertThat(valueArray.getShort(1)).isEqualTo(2000.toShort)
+ }
+
+ test("map with null values") {
+ val keys = toUTF8Strings("key1", "key2", "key3")
+ val values = Array[Any](UTF8String.fromString("value1"), null, UTF8String.fromString("value3"))
+ val sparkMap = new ArrayBasedMapData(ArrayData.toArrayData(keys), ArrayData.toArrayData(values))
+ val mapType = MapType(StringType, StringType)
+ val flussMap = SparkAsFlussMap(sparkMap, mapType)
+
+ assertThat(flussMap.size()).isEqualTo(3)
+ val valueArray = flussMap.valueArray()
+ assertThat(valueArray.isNullAt(0)).isFalse()
+ assertThat(valueArray.isNullAt(1)).isTrue()
+ assertThat(valueArray.isNullAt(2)).isFalse()
+ assertThat(valueArray.getString(0).toString).isEqualTo("value1")
+ assertThat(valueArray.getString(2).toString).isEqualTo("value3")
+ }
+
+ test("map with numeric keys") {
+ val keys = Array(1, 2, 3)
+ val values = toUTF8Strings("value1", "value2", "value3")
+ val sparkMap = new ArrayBasedMapData(ArrayData.toArrayData(keys), ArrayData.toArrayData(values))
+ val mapType = MapType(IntegerType, StringType)
+ val flussMap = SparkAsFlussMap(sparkMap, mapType)
+
+ assertThat(flussMap.size()).isEqualTo(3)
+ val keyArray = flussMap.keyArray()
+ val valueArray = flussMap.valueArray()
+ assertThat(keyArray.getInt(0)).isEqualTo(1)
+ assertThat(keyArray.getInt(1)).isEqualTo(2)
+ assertThat(keyArray.getInt(2)).isEqualTo(3)
+ assertThat(valueArray.getString(0).toString).isEqualTo("value1")
+ assertThat(valueArray.getString(1).toString).isEqualTo("value2")
+ assertThat(valueArray.getString(2).toString).isEqualTo("value3")
+ }
+
+ test("map with complex nested structure: array of rows") {
+ val keys = toUTF8Strings("data1", "data2")
+ val values = Array(
+ ArrayData.toArrayData(
+ Array(
+ InternalRow.apply(UTF8String.fromString("name1"), 100),
+ InternalRow.apply(UTF8String.fromString("name2"), 200)
+ )),
+ ArrayData.toArrayData(
+ Array(
+ InternalRow.apply(UTF8String.fromString("name3"), 300),
+ InternalRow.apply(UTF8String.fromString("name4"), 400)
+ ))
+ )
+ val sparkMap = new ArrayBasedMapData(ArrayData.toArrayData(keys), ArrayData.toArrayData(values))
+ val structType = new StructType().add("name", StringType).add("value", IntegerType)
+ val arrayType = org.apache.spark.sql.types.ArrayType(structType)
+ val mapType = MapType(StringType, arrayType)
+ val flussMap = SparkAsFlussMap(sparkMap, mapType)
+
+ assertThat(flussMap.size()).isEqualTo(2)
+ val keyArray = flussMap.keyArray()
+ assertThat(keyArray.getString(0).toString).isEqualTo("data1")
+ assertThat(keyArray.getString(1).toString).isEqualTo("data2")
+
+ val valueArray = flussMap.valueArray()
+ val array1 = valueArray.getArray(0)
+ assertThat(array1.size()).isEqualTo(2)
+ val row1_0 = array1.getRow(0, 2)
+ assertThat(row1_0.getString(0).toString).isEqualTo("name1")
+ assertThat(row1_0.getInt(1)).isEqualTo(100)
+ val row1_1 = array1.getRow(1, 2)
+ assertThat(row1_1.getString(0).toString).isEqualTo("name2")
+ assertThat(row1_1.getInt(1)).isEqualTo(200)
+
+ val array2 = valueArray.getArray(1)
+ assertThat(array2.size()).isEqualTo(2)
+ val row2_0 = array2.getRow(0, 2)
+ assertThat(row2_0.getString(0).toString).isEqualTo("name3")
+ assertThat(row2_0.getInt(1)).isEqualTo(300)
+ val row2_1 = array2.getRow(1, 2)
+ assertThat(row2_1.getString(0).toString).isEqualTo("name4")
+ assertThat(row2_1.getInt(1)).isEqualTo(400)
+ }
+
+ test("map with decimal values") {
+ val keys = toUTF8Strings("dec1", "dec2")
+ val values = Array(
+ org.apache.spark.sql.types.Decimal(123.45),
+ org.apache.spark.sql.types.Decimal(678.90)
+ )
+ val sparkMap = new ArrayBasedMapData(ArrayData.toArrayData(keys), ArrayData.toArrayData(values))
+ val mapType = MapType(StringType, org.apache.spark.sql.types.DecimalType(10, 2))
+ val flussMap = SparkAsFlussMap(sparkMap, mapType)
+
+ assertThat(flussMap.size()).isEqualTo(2)
+ val keyArray = flussMap.keyArray()
+ assertThat(keyArray.getString(0).toString).isEqualTo("dec1")
+ assertThat(keyArray.getString(1).toString).isEqualTo("dec2")
+
+ val valueArray = flussMap.valueArray()
+ val dec1 = valueArray.getDecimal(0, 10, 2)
+ assertThat(dec1.toBigDecimal.compareTo(new java.math.BigDecimal("123.45"))).isZero()
+ val dec2 = valueArray.getDecimal(1, 10, 2)
+ assertThat(dec2.toBigDecimal.compareTo(new java.math.BigDecimal("678.90"))).isZero()
+ }
+
+ test("map with timestamp values") {
+ val keys = toUTF8Strings("ts1", "ts2")
+ val values = Array(
+ 1634567890123456L, // microseconds timestamp
+ 1634567891123456L
+ )
+ val sparkMap = new ArrayBasedMapData(ArrayData.toArrayData(keys), ArrayData.toArrayData(values))
+ val mapType = MapType(StringType, org.apache.spark.sql.types.TimestampType)
+ val flussMap = SparkAsFlussMap(sparkMap, mapType)
+
+ assertThat(flussMap.size()).isEqualTo(2)
+ val keyArray = flussMap.keyArray()
+ assertThat(keyArray.getString(0).toString).isEqualTo("ts1")
+ assertThat(keyArray.getString(1).toString).isEqualTo("ts2")
+
+ val valueArray = flussMap.valueArray()
+ val ts1 = valueArray.getTimestampNtz(0, 6)
+ assertThat(ts1.toEpochMicros).isEqualTo(1634567890123456L)
+ val ts2 = valueArray.getTimestampNtz(1, 6)
+ assertThat(ts2.toEpochMicros).isEqualTo(1634567891123456L)
+ }
+}
diff --git a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/SparkAsFlussRowTest.scala b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/SparkAsFlussRowTest.scala
index 6a5240427d..3cdae4c8b8 100644
--- a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/SparkAsFlussRowTest.scala
+++ b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/row/SparkAsFlussRowTest.scala
@@ -21,7 +21,7 @@ import org.apache.fluss.spark.FlussSparkTestBase
import org.apache.fluss.spark.util.TestUtils.SCHEMA
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
import org.apache.spark.sql.types.Decimal
import org.apache.spark.unsafe.types.UTF8String
import org.assertj.core.api.Assertions.assertThat
@@ -48,13 +48,16 @@ class SparkAsFlussRowTest extends FlussSparkTestBase {
UTF8String.fromString("test"),
Timestamp.valueOf("2025-12-31 10:00:00").getTime * 1000,
new GenericArrayData(Array(11.11f, 22.22f)),
+ ArrayBasedMapData.apply(
+ Array[Any](UTF8String.fromString("k1"), UTF8String.fromString("k2")),
+ Array(111, 222)),
InternalRow.apply(123L, UTF8String.fromString("apache fluss"))
))
row = new SparkAsFlussRow(SCHEMA).replace(data)
}
test("Fluss SparkAsFlussRow") {
- assertThat(row.fieldCount).isEqualTo(13)
+ assertThat(row.fieldCount).isEqualTo(14)
assertThat(row.getBoolean(0)).isEqualTo(true)
assertThat(row.getByte(1)).isEqualTo(1.toByte)
@@ -73,8 +76,14 @@ class SparkAsFlussRowTest extends FlussSparkTestBase {
// test array type
assertThat(row.getArray(11).toFloatArray).containsExactly(Array(11.11f, 22.22f): _*)
+ // test map type
+ assertThat(row.getMap(12).size()).isEqualTo(2)
+ assertThat(row.getMap(12).keyArray().getString(0).toString).isEqualTo("k1")
+ assertThat(row.getMap(12).keyArray().getString(1).toString).isEqualTo("k2")
+ assertThat(row.getMap(12).valueArray().toIntArray).containsExactly(Array(111, 222): _*)
+
// test row type
- val nestedRow = row.getRow(12, 2)
+ val nestedRow = row.getRow(13, 2)
assertThat(nestedRow.getFieldCount).isEqualTo(2)
assertThat(nestedRow.getLong(0)).isEqualTo(123L)
assertThat(nestedRow.getString(1).toString).isEqualTo("apache fluss")
diff --git a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/util/TestUtils.scala b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/util/TestUtils.scala
index 704ca2ed50..bb06cd730b 100644
--- a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/util/TestUtils.scala
+++ b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/util/TestUtils.scala
@@ -40,6 +40,7 @@ object TestUtils {
// StructField("date", DateType),
StructField("c_timestamp", TimestampType),
StructField("c_array", ArrayType(FloatType, containsNull = false)),
+ StructField("c_map", MapType(StringType, IntegerType, valueContainsNull = false)),
StructField(
"c_row",
StructType(Seq(StructField("id", LongType), StructField("name", StringType))))