diff --git a/gluten-flink/planner/src/main/java/org/apache/gluten/velox/NexmarkSourceFactory.java b/gluten-flink/planner/src/main/java/org/apache/gluten/velox/NexmarkSourceFactory.java index 6533935bc75..0767cd38299 100644 --- a/gluten-flink/planner/src/main/java/org/apache/gluten/velox/NexmarkSourceFactory.java +++ b/gluten-flink/planner/src/main/java/org/apache/gluten/velox/NexmarkSourceFactory.java @@ -24,6 +24,7 @@ import io.github.zhztheplayer.velox4j.connector.NexmarkConnectorSplit; import io.github.zhztheplayer.velox4j.connector.NexmarkGeneratorConfig; +import io.github.zhztheplayer.velox4j.connector.NexmarkParallelSplit; import io.github.zhztheplayer.velox4j.connector.NexmarkTableHandle; import io.github.zhztheplayer.velox4j.plan.PlanNode; import io.github.zhztheplayer.velox4j.plan.StatefulPlanNode; @@ -43,6 +44,7 @@ import com.fasterxml.jackson.databind.DeserializationFeature; import com.fasterxml.jackson.databind.ObjectMapper; +import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -81,22 +83,29 @@ public Transformation buildVeloxSource( "getSplits", new Class[] {int.class}, new Object[] {transformation.getParallelism()}); - Object nexmarkSourceSplit = nexmarkSourceSplits.get(0); - Object generatorConfig = - ReflectUtils.getObjectField( - nexmarkSourceSplit.getClass(), nexmarkSourceSplit, "generatorConfig"); + + // Convert each subtask's NexmarkGeneratorConfig to velox4j + List subtaskSplits = new ArrayList<>(); + for (Object nexmarkSourceSplit : nexmarkSourceSplits) { + Object generatorConfig = + ReflectUtils.getObjectField( + nexmarkSourceSplit.getClass(), nexmarkSourceSplit, "generatorConfig"); + subtaskSplits.add( + new NexmarkConnectorSplit( + "connector-nexmark", toVeloxNexmarkGeneratorConfig(generatorConfig))); + } + PlanNode tableScan = new TableScanNode(id, outputType, new NexmarkTableHandle("connector-nexmark"), List.of()); + NexmarkParallelSplit split = new NexmarkParallelSplit("connector-nexmark", subtaskSplits); GlutenStreamSource sourceOp = new GlutenStreamSource( new GlutenSourceFunction( new StatefulPlanNode(tableScan.getId(), tableScan), Map.of(id, outputType), id, - new NexmarkConnectorSplit( - "connector-nexmark", toVeloxNexmarkGeneratorConfig(generatorConfig)), + split, RowData.class)); - return new LegacySourceTransformation( transformation.getName(), sourceOp, @@ -112,6 +121,7 @@ public Transformation buildVeloxSink( throw new UnsupportedOperationException("Unimplemented method 'buildSink'"); } + /** Convert Flink nexmark NexmarkGeneratorConfig to velox4j NexmarkGeneratorConfig via Jackson. */ private static NexmarkGeneratorConfig toVeloxNexmarkGeneratorConfig(Object javaConfig) { try { String json = MAPPER.writeValueAsString(javaConfig); diff --git a/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenSourceFunction.java b/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenSourceFunction.java index 53f36fcf67c..76cced93f15 100644 --- a/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenSourceFunction.java +++ b/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenSourceFunction.java @@ -22,6 +22,7 @@ import org.apache.gluten.vectorized.FlinkRowToVLVectorConvertor; import io.github.zhztheplayer.velox4j.connector.ConnectorSplit; +import io.github.zhztheplayer.velox4j.connector.ParallelSplit; import io.github.zhztheplayer.velox4j.iterator.UpIterator; import io.github.zhztheplayer.velox4j.plan.StatefulPlanNode; import io.github.zhztheplayer.velox4j.query.Query; @@ -230,6 +231,14 @@ private void initSession() { if (sessionResource != null) { return; } + + ConnectorSplit activeSplit = split; + int totalParallelism = getRuntimeContext().getNumberOfParallelSubtasks(); + int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask(); + if (split instanceof ParallelSplit) { + activeSplit = ((ParallelSplit) split).getSubtaskSplit(subtaskIndex, totalParallelism); + } + sessionResource = new GlutenSessionResource(); GlutenSessionResources.getInstance().addSessionResource(id, sessionResource); Session session = sessionResource.getSession(); @@ -239,7 +248,7 @@ private void initSession() { VeloxQueryConfig.getConfig(getRuntimeContext()), VeloxConnectorConfig.getConfig(getRuntimeContext())); task = session.queryOps().execute(query); - task.addSplit(id, split); + task.addSplit(id, activeSplit); task.noMoreSplits(id); taskMetrics = new SourceTaskMetrics(getRuntimeContext().getMetricGroup()); } diff --git a/gluten-flink/ut/src/test/java/org/apache/gluten/velox/NexmarkSourceFactoryTest.java b/gluten-flink/ut/src/test/java/org/apache/gluten/velox/NexmarkSourceFactoryTest.java new file mode 100644 index 00000000000..8f149cfb73d --- /dev/null +++ b/gluten-flink/ut/src/test/java/org/apache/gluten/velox/NexmarkSourceFactoryTest.java @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.velox; + +import org.apache.gluten.streaming.api.operators.GlutenStreamSource; + +import io.github.zhztheplayer.velox4j.connector.ConnectorSplit; +import io.github.zhztheplayer.velox4j.connector.NexmarkConnectorSplit; +import io.github.zhztheplayer.velox4j.connector.NexmarkGeneratorConfig; +import io.github.zhztheplayer.velox4j.connector.NexmarkParallelSplit; + +import org.apache.flink.api.dag.Transformation; +import org.apache.flink.streaming.api.transformations.LegacySourceTransformation; +import org.apache.flink.streaming.api.transformations.SourceTransformation; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.runtime.typeutils.InternalTypeInfo; +import org.apache.flink.table.types.logical.IntType; +import org.apache.flink.table.types.logical.RowType; + +import org.junit.jupiter.api.Test; + +import java.lang.reflect.Constructor; +import java.util.Collections; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class NexmarkSourceFactoryTest { + + private static final String NEXMARK_SOURCE_CN = "com.github.nexmark.flink.source.NexmarkSource"; + private static final String NEXMARK_CONFIG_CN = "com.github.nexmark.flink.NexmarkConfiguration"; + private static final String GENERATOR_CONFIG_CN = + "com.github.nexmark.flink.generator.GeneratorConfig"; + + @SuppressWarnings("rawtypes") + @Test + void testBuildVeloxSourceWrapsSplitsInNexmarkParallelSplit() throws Exception { + SourceTransformation tx = newSourceTransformation(/* parallelism= */ 2); + + NexmarkSourceFactory factory = new NexmarkSourceFactory(); + Transformation result = factory.buildVeloxSource(tx, Collections.emptyMap()); + + LegacySourceTransformation legacy = + assertInstanceOf(LegacySourceTransformation.class, result); + GlutenStreamSource streamSource = + assertInstanceOf(GlutenStreamSource.class, legacy.getOperator()); + + ConnectorSplit split = streamSource.getConnectorSplit(); + NexmarkParallelSplit parallel = assertInstanceOf(NexmarkParallelSplit.class, split); + + NexmarkConnectorSplit s0 = + assertInstanceOf(NexmarkConnectorSplit.class, parallel.getSubtaskSplit(0, 2)); + NexmarkConnectorSplit s1 = + assertInstanceOf(NexmarkConnectorSplit.class, parallel.getSubtaskSplit(1, 2)); + + NexmarkGeneratorConfig c0 = s0.getConfig(); + NexmarkGeneratorConfig c1 = s1.getConfig(); + assertEquals(0L, c0.getFirstEventId()); + assertEquals(500L, c0.getMaxEventsOrZero()); + assertEquals(500L, c1.getFirstEventId()); + assertEquals(500L, c1.getMaxEventsOrZero()); + } + + @SuppressWarnings("rawtypes") + @Test + void testBuildVeloxSourceAtParallelismOneStillProducesParallelSplit() throws Exception { + SourceTransformation tx = newSourceTransformation(/* parallelism= */ 1); + + NexmarkSourceFactory factory = new NexmarkSourceFactory(); + Transformation result = factory.buildVeloxSource(tx, Collections.emptyMap()); + + LegacySourceTransformation legacy = + assertInstanceOf(LegacySourceTransformation.class, result); + GlutenStreamSource streamSource = + assertInstanceOf(GlutenStreamSource.class, legacy.getOperator()); + + NexmarkParallelSplit parallel = + assertInstanceOf(NexmarkParallelSplit.class, streamSource.getConnectorSplit()); + NexmarkConnectorSplit s0 = + assertInstanceOf(NexmarkConnectorSplit.class, parallel.getSubtaskSplit(0, 1)); + + assertEquals(0L, s0.getConfig().getFirstEventId()); + assertEquals(1000L, s0.getConfig().getMaxEventsOrZero()); + } + + @Test + void testBuildVeloxSourceRejectsNonSourceTransformation() { + NexmarkSourceFactory factory = new NexmarkSourceFactory(); + assertThrows( + ClassCastException.class, + () -> factory.buildVeloxSource(new StubTransformation(), Collections.emptyMap())); + } + + @SuppressWarnings("rawtypes") + private static SourceTransformation newSourceTransformation(int parallelism) throws Exception { + Object nexmarkSource = newNexmarkSource(1000L); + Constructor ctor = + SourceTransformation.class.getDeclaredConstructor( + String.class, + org.apache.flink.api.connector.source.Source.class, + org.apache.flink.api.common.eventtime.WatermarkStrategy.class, + org.apache.flink.api.common.typeinfo.TypeInformation.class, + int.class); + return (SourceTransformation) + ctor.newInstance( + "nexmark-source", + nexmarkSource, + org.apache.flink.api.common.eventtime.WatermarkStrategy.noWatermarks(), + InternalTypeInfo.of(RowType.of(new IntType())), + parallelism); + } + + private static Object newNexmarkSource(long maxEvents) throws Exception { + Object nexmarkConfig = Class.forName(NEXMARK_CONFIG_CN).getDeclaredConstructor().newInstance(); + java.lang.reflect.Field numEvents = nexmarkConfig.getClass().getDeclaredField("numEvents"); + numEvents.setAccessible(true); + numEvents.setLong(nexmarkConfig, maxEvents); + + Class generatorConfigCls = Class.forName(GENERATOR_CONFIG_CN); + Constructor generatorConfigCtor = + generatorConfigCls.getDeclaredConstructor( + Class.forName(NEXMARK_CONFIG_CN), + long.class, + long.class, + long.class, + long.class, + long.class); + Object generatorConfig = + generatorConfigCtor.newInstance(nexmarkConfig, 0L, 0L, maxEvents, maxEvents, 0L); + + Class nexmarkSourceCls = Class.forName(NEXMARK_SOURCE_CN); + Constructor nexmarkSourceCtor = + nexmarkSourceCls.getDeclaredConstructor( + generatorConfigCls, org.apache.flink.api.common.typeinfo.TypeInformation.class); + nexmarkSourceCtor.setAccessible(true); + return nexmarkSourceCtor.newInstance( + generatorConfig, InternalTypeInfo.of(RowType.of(new IntType()))); + } + + private static final class StubTransformation extends Transformation { + StubTransformation() { + super("stub", InternalTypeInfo.of(RowType.of(new IntType())), 1); + } + + @Override + public java.util.List> getInputs() { + return Collections.emptyList(); + } + + @Override + protected java.util.List> getTransitivePredecessorsInternal() { + return Collections.emptyList(); + } + } +}