diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchPortablePipelineTranslator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchPortablePipelineTranslator.java index 4d9a5a516c75..19c943b67747 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchPortablePipelineTranslator.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchPortablePipelineTranslator.java @@ -46,6 +46,7 @@ import org.apache.beam.runners.flink.translation.functions.FlinkReduceFunction; import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; import org.apache.beam.runners.flink.translation.types.KvKeySelector; +import org.apache.beam.runners.flink.translation.utils.LargeRecordFilterFunction; import org.apache.beam.runners.flink.translation.wrappers.ImpulseInputFormat; import org.apache.beam.runners.fnexecution.provisioning.JobInfo; import org.apache.beam.runners.fnexecution.wire.WireCoders; @@ -92,6 +93,8 @@ import org.apache.flink.api.java.operators.SingleInputUdfOperator; import org.apache.flink.api.java.operators.UnsortedGrouping; import org.checkerframework.checker.nullness.qual.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * A translator that translates bounded portable pipelines into executable Flink pipelines. @@ -119,6 +122,8 @@ public class FlinkBatchPortablePipelineTranslator implements FlinkPortablePipelineTranslator< FlinkBatchPortablePipelineTranslator.BatchTranslationContext> { + private static final Logger LOG = + LoggerFactory.getLogger(FlinkBatchPortablePipelineTranslator.class); /** * Creates a batch translation context. The resulting Flink execution dag will live in a new @@ -206,6 +211,7 @@ public FlinkPipelineOptions getPipelineOptions() { @Override public JobExecutionResult execute(String jobName) throws Exception { + LOG.info("Executing Flink batch job with name: {}", jobName); return getExecutionEnvironment().execute(jobName); } @@ -515,8 +521,14 @@ private static void translateGroupByKey( TypeInformation>>> partialReduceTypeInfo = new CoderTypeInformation<>(outputCoder, context.getPipelineOptions()); + ///////////////////////// BEGIN GLEAN MODIFICATION /////////////////////////////// + LOG.info("Add step to filter large records before GroupBy"); + DataSet>> filteredDataSet = + inputDataSet.filter(new LargeRecordFilterFunction<>()); + Grouping>> inputGrouping = - inputDataSet.groupBy(new KvKeySelector<>(inputElementCoder.getKeyCoder())); + filteredDataSet.groupBy(new KvKeySelector<>(inputElementCoder.getKeyCoder())); + ///////////////////////// END GLEAN MODIFICATION ///////////////////////////////// FlinkPartialReduceFunction, ?> partialReduceFunction = new FlinkPartialReduceFunction<>( diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/utils/LargeRecordFilterFunction.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/utils/LargeRecordFilterFunction.java new file mode 100644 index 000000000000..01fd2f831bbd --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/utils/LargeRecordFilterFunction.java @@ -0,0 +1,45 @@ +package org.apache.beam.runners.flink.translation.utils; + +import java.util.List; +import org.apache.beam.runners.flink.FlinkBatchPortablePipelineTranslator; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.KV; +import org.apache.flink.api.common.functions.FilterFunction; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * [Glean] + * FilterFunction that filters out large records based on a size threshold. + */ +public class LargeRecordFilterFunction implements + FilterFunction>> { + private static final Logger LOG = + LoggerFactory.getLogger(LargeRecordFilterFunction.class); + private static final long MAX_RECORD_SIZE = 5000000; // 5 MB + + @Override + public boolean filter(WindowedValue> windowedValue) throws Exception { + KV kv = windowedValue.getValue(); + long size = getObjectSize(kv.getKey()) + getObjectSize(kv.getValue()); + if (size >= MAX_RECORD_SIZE) { + LOG.warn("Dropping large record with size: {}", size); + return false; + } + return true; + } + + /** + * Calculate the size of an object in bytes. + * This is a simplified version for objects used in portability. + */ + private static long getObjectSize(T o) { + if (o instanceof byte[]) { + return ((byte[]) o).length; + } else if(o instanceof List) { + return ((List) o).stream().mapToLong(LargeRecordFilterFunction::getObjectSize).sum(); + } else { + return 0; // for other types, we don't calculate size + } + } +}