Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.apache.hadoop.fs.Path

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.avro.AvroOptions
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.connector.read.PartitionReaderFactory
import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
Expand All @@ -31,6 +32,7 @@ import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.ArrayImplicits._
import org.apache.spark.util.SerializableConfiguration
import org.apache.spark.util.collection.BitSet

case class AvroScan(
sparkSession: SparkSession,
Expand All @@ -41,7 +43,13 @@ case class AvroScan(
options: CaseInsensitiveStringMap,
pushedFilters: Array[Filter],
partitionFilters: Seq[Expression] = Seq.empty,
dataFilters: Seq[Expression] = Seq.empty) extends FileScan {
dataFilters: Seq[Expression] = Seq.empty,
override val bucketSpec: Option[BucketSpec] = None,
override val disableBucketedScan: Boolean = false,
override val optionalBucketSet: Option[BitSet] = None,
override val optionalNumCoalescedBuckets: Option[Int] = None,
override val requestedMetadataFields: StructType = StructType(Seq.empty))
extends FileScan {
override def isSplitable(path: Path): Boolean = true

override def createReaderFactory(): PartitionReaderFactory = {
Expand All @@ -52,14 +60,15 @@ case class AvroScan(
val parsedOptions = new AvroOptions(caseSensitiveMap, hadoopConf)
// The partition values are already truncated in `FileScan.partitions`.
// We should use `readPartitionSchema` as the partition schema here.
AvroPartitionReaderFactory(
val baseFactory = AvroPartitionReaderFactory(
conf,
broadcastedConf,
dataSchema,
readDataSchema,
readPartitionSchema,
parsedOptions,
pushedFilters.toImmutableArraySeq)
wrapWithMetadataIfNeeded(baseFactory, readDataSchema, options)
}

override def equals(obj: Any): Boolean = obj match {
Expand All @@ -70,6 +79,15 @@ case class AvroScan(

override def hashCode(): Int = super.hashCode()

override def withFileIndex(newFI: PartitioningAwareFileIndex): AvroScan =
copy(fileIndex = newFI)

override def withDisableBucketedScan(disable: Boolean): AvroScan =
copy(disableBucketedScan = disable)

override def withNumCoalescedBuckets(n: Option[Int]): AvroScan =
copy(optionalNumCoalescedBuckets = n)

override def getMetaData(): Map[String, String] = {
super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters.toImmutableArraySeq))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package org.apache.spark.sql.v2.avro

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.StructFilters
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder
import org.apache.spark.sql.sources.Filter
Expand All @@ -29,10 +30,12 @@ case class AvroScanBuilder (
fileIndex: PartitioningAwareFileIndex,
schema: StructType,
dataSchema: StructType,
options: CaseInsensitiveStringMap)
extends FileScanBuilder(sparkSession, fileIndex, dataSchema) {
options: CaseInsensitiveStringMap,
override val bucketSpec: Option[BucketSpec] = None)
extends FileScanBuilder(sparkSession, fileIndex, dataSchema, bucketSpec) {

override def build(): AvroScan = {
val optBucketSet = computeBucketSet()
AvroScan(
sparkSession,
fileIndex,
Expand All @@ -42,7 +45,10 @@ case class AvroScanBuilder (
options,
pushedDataFilters,
partitionFilters,
dataFilters)
dataFilters,
bucketSpec = bucketSpec,
optionalBucketSet = optBucketSet,
requestedMetadataFields = requestedMetadataFields)
}

override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import org.apache.hadoop.fs.FileStatus

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.avro.AvroUtils
import org.apache.spark.sql.connector.write.{LogicalWriteInfo, Write, WriteBuilder}
import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder}
import org.apache.spark.sql.execution.datasources.FileFormat
import org.apache.spark.sql.execution.datasources.v2.FileTable
import org.apache.spark.sql.types.{DataType, StructType}
Expand All @@ -37,19 +37,21 @@ case class AvroTable(
fallbackFileFormat: Class[_ <: FileFormat])
extends FileTable(sparkSession, options, paths, userSpecifiedSchema) {
override def newScanBuilder(options: CaseInsensitiveStringMap): AvroScanBuilder =
AvroScanBuilder(sparkSession, fileIndex, schema, dataSchema, mergedOptions(options))
AvroScanBuilder(sparkSession, fileIndex, schema, dataSchema, mergedOptions(options),
bucketSpec)

override def inferSchema(files: Seq[FileStatus]): Option[StructType] =
AvroUtils.inferSchema(sparkSession, options.asScala.toMap, files)

override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
new WriteBuilder {
override def build(): Write =
AvroWrite(paths, formatName, supportsDataType, mergedWriteInfo(info))
createFileWriteBuilder(info) {
(mergedInfo, partSchema, bSpec, customLocs, dynamicOverwrite, truncate, overPreds) =>
AvroWrite(paths, formatName, supportsDataType, mergedInfo, partSchema, bSpec,
overPreds, customLocs, dynamicOverwrite, truncate)
}
}

override def supportsDataType(dataType: DataType): Boolean = AvroUtils.supportsDataType(dataType)

override def formatName: String = "AVRO"
override def formatName: String = "Avro"
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package org.apache.spark.sql.v2.avro
import org.apache.hadoop.mapreduce.Job

import org.apache.spark.sql.avro.AvroUtils
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.write.LogicalWriteInfo
import org.apache.spark.sql.execution.datasources.OutputWriterFactory
import org.apache.spark.sql.execution.datasources.v2.FileWrite
Expand All @@ -29,7 +31,13 @@ case class AvroWrite(
paths: Seq[String],
formatName: String,
supportsDataType: DataType => Boolean,
info: LogicalWriteInfo) extends FileWrite {
info: LogicalWriteInfo,
partitionSchema: StructType,
override val bucketSpec: Option[BucketSpec] = None,
override val overwritePredicates: Option[Array[Predicate]] = None,
override val customPartitionLocations: Map[Map[String, String], String] = Map.empty,
override val dynamicPartitionOverwrite: Boolean,
override val isTruncate: Boolean) extends FileWrite {
override def prepareWrite(
sqlConf: SQLConf,
job: Job,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1159,14 +1159,16 @@ class Analyzer(
throw QueryCompilationErrors.unsupportedInsertReplaceOnOrUsing(
i.table.asInstanceOf[DataSourceV2Relation].table.name())

case i: InsertIntoStatement
if i.table.isInstanceOf[DataSourceV2Relation] &&
i.query.resolved &&
i.replaceCriteriaOpt.isEmpty =>
val r = i.table.asInstanceOf[DataSourceV2Relation]
// ifPartitionNotExists is append with validation, but validation is not supported
case i @ InsertIntoStatement(r: DataSourceV2Relation, _, _, _, _, _, _, _, _)
if i.query.resolved && i.replaceCriteriaOpt.isEmpty =>
// SPARK-56304: allow ifPartitionNotExists for tables that
// support partition management and overwrite-by-filter
if (i.ifPartitionNotExists) {
throw QueryCompilationErrors.unsupportedIfNotExistsError(r.table.name)
val caps = r.table.capabilities
if (!caps.contains(TableCapability.OVERWRITE_BY_FILTER) ||
!r.table.isInstanceOf[SupportsPartitionManagement]) {
throw QueryCompilationErrors.unsupportedIfNotExistsError(r.table.name)
}
}

// Create a project if this is an INSERT INTO BY NAME query.
Expand Down Expand Up @@ -1209,17 +1211,27 @@ class Analyzer(
withSchemaEvolution = i.withSchemaEvolution)
}
} else {
val extraOpts = if (i.ifPartitionNotExists) {
Map("ifPartitionNotExists" -> "true") ++
staticPartitions.map { case (k, v) =>
s"__staticPartition.$k" -> v
}
} else {
Map.empty[String, String]
}
if (isByName) {
OverwriteByExpression.byName(
table = r,
df = query,
deleteExpr = staticDeleteExpression(r, staticPartitions),
writeOptions = extraOpts,
withSchemaEvolution = i.withSchemaEvolution)
} else {
OverwriteByExpression.byPosition(
table = r,
query = query,
deleteExpr = staticDeleteExpression(r, staticPartitions),
writeOptions = extraOpts,
withSchemaEvolution = i.withSchemaEvolution)
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources.v2;

import org.apache.spark.sql.types.Decimal;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.vectorized.ColumnVector;
import org.apache.spark.sql.vectorized.ColumnarArray;
import org.apache.spark.sql.vectorized.ColumnarMap;
import org.apache.spark.unsafe.types.UTF8String;

/**
 * A struct-typed {@link ColumnVector} that presents a fixed array of arbitrary child
 * vectors as the fields of a single struct column. Used to assemble the V2
 * {@code _metadata} struct from a mix of
 * {@link org.apache.spark.sql.execution.vectorized.ConstantColumnVector}s (per-file
 * constants such as {@code file_path}) and per-row vectors produced by the format
 * reader (for example Parquet's {@code _tmp_metadata_row_index}).
 *
 * <p>Only {@link #getChild(int)}, {@link #isNullAt(int)}, and {@link #close()} carry
 * real behavior. {@link ColumnVector#getStruct(int)} resolves struct field access
 * through {@code getChild}, so the scalar getters can never be reached legitimately
 * and all of them throw.
 *
 * <p>Ownership of the children stays with whoever produced them (the input batch or
 * the metadata wrapper); closing this vector is deliberately a no-op.
 */
final class CompositeStructColumnVector extends ColumnVector {

  // One child vector per struct field, in field order; entries are never null
  // (enforced by the constructor).
  private final ColumnVector[] childVectors;

  CompositeStructColumnVector(StructType type, ColumnVector[] children) {
    super(type);
    // Fail fast on a malformed composition rather than on first row access.
    if (children.length != type.fields().length) {
      throw new IllegalArgumentException(
        "Children count " + children.length + " does not match struct field count "
        + type.fields().length);
    }
    int idx = 0;
    while (idx < children.length) {
      if (children[idx] == null) {
        throw new IllegalArgumentException("Child column vector at index " + idx + " is null");
      }
      idx++;
    }
    this.childVectors = children;
  }

  // The children are owned by the underlying batch / extractor, so there is
  // deliberately nothing to release here.
  @Override
  public void close() {}

  // The composed struct value itself is never null at any row; nullability of
  // individual fields is carried by the children.
  @Override
  public boolean hasNull() {
    return false;
  }

  @Override
  public int numNulls() {
    return 0;
  }

  @Override
  public boolean isNullAt(int rowId) {
    return false;
  }

  @Override
  public ColumnVector getChild(int ordinal) {
    return childVectors[ordinal];
  }

  // Everything below exists only to satisfy the abstract contract of ColumnVector;
  // struct field access never routes through these scalar getters.
  @Override
  public boolean getBoolean(int rowId) {
    throw unsupported("getBoolean");
  }

  @Override
  public byte getByte(int rowId) {
    throw unsupported("getByte");
  }

  @Override
  public short getShort(int rowId) {
    throw unsupported("getShort");
  }

  @Override
  public int getInt(int rowId) {
    throw unsupported("getInt");
  }

  @Override
  public long getLong(int rowId) {
    throw unsupported("getLong");
  }

  @Override
  public float getFloat(int rowId) {
    throw unsupported("getFloat");
  }

  @Override
  public double getDouble(int rowId) {
    throw unsupported("getDouble");
  }

  @Override
  public ColumnarArray getArray(int rowId) {
    throw unsupported("getArray");
  }

  @Override
  public ColumnarMap getMap(int ordinal) {
    throw unsupported("getMap");
  }

  @Override
  public Decimal getDecimal(int rowId, int precision, int scale) {
    throw unsupported("getDecimal");
  }

  @Override
  public UTF8String getUTF8String(int rowId) {
    throw unsupported("getUTF8String");
  }

  @Override
  public byte[] getBinary(int rowId) {
    throw unsupported("getBinary");
  }

  // Builds (does not throw) the exception so call sites read as `throw unsupported(...)`.
  private UnsupportedOperationException unsupported(String method) {
    return new UnsupportedOperationException(
      method + " is not supported on " + getClass().getSimpleName()
      + "; access struct fields via getChild()");
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,24 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
/** Resolves the relations created from the DataFrameReader and DataStreamReader APIs. */
class ResolveDataSource(sparkSession: SparkSession) extends Rule[LogicalPlan] {

/**
 * Returns true if `provider` is a [[FileDataSourceV2]] whose short name or fully
 * qualified class name appears in `USE_V1_SOURCE_LIST`. When true, streaming falls
 * back to the V1 code path.
 *
 * Matching is case-insensitive. Empty entries in the conf value (e.g. from an empty
 * string or stray commas) are ignored so a provider without a registered short name
 * can never spuriously match an empty list entry.
 */
private def isV1FileSource(provider: TableProvider): Boolean = {
  provider.isInstanceOf[FileDataSourceV2] && {
    // Normalize once per call: lower-case, trim, and drop empty entries so that
    // "" or "csv,,json" cannot produce a "" entry that matches everything below.
    val v1Sources = sparkSession.sessionState.conf
      .getConf(org.apache.spark.sql.internal.SQLConf.USE_V1_SOURCE_LIST)
      .toLowerCase(Locale.ROOT).split(",").map(_.trim).filter(_.nonEmpty).toSet
    val shortName = provider match {
      case d: org.apache.spark.sql.sources.DataSourceRegister => d.shortName()
      case _ => ""
    }
    // getCanonicalName returns null for anonymous/local classes; wrap in Option
    // to avoid an NPE on the toLowerCase call.
    val className = Option(provider.getClass.getCanonicalName)
      .map(_.toLowerCase(Locale.ROOT))
    (shortName.nonEmpty && v1Sources.contains(shortName.toLowerCase(Locale.ROOT))) ||
      className.exists(v1Sources.contains)
  }
}

override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
case UnresolvedDataSource(source, userSpecifiedSchema, extraOptions, false, paths) =>
// Batch data source created from DataFrameReader
Expand Down Expand Up @@ -92,8 +110,7 @@ class ResolveDataSource(sparkSession: SparkSession) extends Rule[LogicalPlan] {
case _ => None
}
ds match {
// file source v2 does not support streaming yet.
case provider: TableProvider if !provider.isInstanceOf[FileDataSourceV2] =>
case provider: TableProvider if !isV1FileSource(provider) =>
val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
source = provider, conf = sparkSession.sessionState.conf)
val finalOptions =
Expand Down Expand Up @@ -153,8 +170,7 @@ class ResolveDataSource(sparkSession: SparkSession) extends Rule[LogicalPlan] {
case _ => None
}
ds match {
// file source v2 does not support streaming yet.
case provider: TableProvider if !provider.isInstanceOf[FileDataSourceV2] =>
case provider: TableProvider if !isV1FileSource(provider) =>
val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
source = provider, conf = sparkSession.sessionState.conf)
val finalOptions =
Expand Down
Loading