@@ -18,6 +18,7 @@ import org.apache.spark.ml.{ComplexParamsWritable, Estimator, Model}
import org.apache.spark.sql._
import org.apache.spark.sql.types._

import scala.annotation.tailrec
import scala.collection.immutable.HashSet
import scala.language.existentials
import scala.math.min
@@ -199,7 +200,7 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel] with LightGBMModelParams]

private def getSlotNamesWithMetadata(featuresSchema: StructField): Option[Array[String]] = {
if (getSlotNames.nonEmpty) {
Some(getSlotNames)
Some(ensureUniqueFeatureNames(getSlotNames))
} else {
AttributeGroup.fromStructField(featuresSchema).attributes.flatMap(attributes =>
if (attributes.isEmpty) {
@@ -208,28 +209,86 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel] with LightGBMModelParams]
val colNames = attributes.indices.map(_.toString).toArray
attributes.foreach(attr =>
attr.index.foreach(index => colNames(index) = attr.name.getOrElse(index.toString)))
Some(colNames)
// Ensure unique feature names to avoid LightGBM error:
// "Feature (Column_) appears more than one time"
// This can occur in Spark 3.5+ due to changes in AttributeGroup metadata handling
Some(ensureUniqueFeatureNames(colNames))
}
)
}
}

/**
* Ensures all feature names are unique by appending indices to duplicates.
* This is necessary because Spark 3.5+ can generate duplicate feature names
* in AttributeGroup metadata, which causes LightGBM to fail with:
* "Feature (Column_) appears more than one time"
*
* @param names The array of feature names that may contain duplicates.
* @return An array with unique feature names.
*/
private def ensureUniqueFeatureNames(names: Array[String]): Array[String] = {
val nameCounts = scala.collection.mutable.Map[String, Int]()
val seenNames = scala.collection.mutable.Set[String]()
val uniqueNames = names.map { name =>
val count = nameCounts.getOrElse(name, 0)
nameCounts(name) = count + 1
if (count > 0) {
// Find a unique suffix using an Iterator to avoid a mutable while loop
val newName = Iterator.from(count)
.map(i => s"${name}_$i")
.find(n => !seenNames.contains(n))
.get // Safe: seenNames is finite, so an unused suffix always exists
seenNames.add(newName)
newName
} else {
seenNames.add(name)
name
}
}

val duplicates = nameCounts.filter(_._2 > 1).keys.toSeq
if (duplicates.nonEmpty) {
log.warn(s"Duplicate feature names detected and renamed: ${duplicates.mkString(", ")}. " +
"This may occur in Spark 3.5+ due to changes in metadata handling. " +
"Consider setting the 'slotNames' parameter explicitly to avoid this.")
}

uniqueNames
}
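// Illustrative sketch, not part of this diff: how the renaming above behaves.
// Given Array("a", "b", "a"), the first two names pass through unchanged; the
// second "a" has count = 1, so Iterator.from(1) yields the candidate "a_1",
// which is unused, producing Array("a", "b", "a_1").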

private def validateSlotNames(featuresSchema: StructField): Unit = {
val metadata = AttributeGroup.fromStructField(featuresSchema)
if (metadata.attributes.isDefined) {
val slotNamesOpt = getSlotNamesWithMetadata(featuresSchema)
val pattern = new Regex("[\",:\\[\\]{}]")
slotNamesOpt.foreach(slotNames => {
val badSlotNames = slotNames.flatMap(slotName =>
if (pattern.findFirstIn(slotName).isEmpty) None else Option(slotName))
if (!badSlotNames.isEmpty) {
throw new IllegalArgumentException(
s"Invalid slot names detected in features column: ${badSlotNames.mkString(",")}" +
" \n Special characters \" , : \\ [ ] { } will cause unexpected behavior in LGBM unless changed." +
" This error can be fixed by renaming the problematic columns prior to vector assembly.")
}
})
}
}

private def shouldRetryWithBulk(error: Throwable): Boolean = {
@tailrec
def hasRetryableMessage(current: Throwable): Boolean = {
if (current == null) {
false
} else {
val message = Option(current.getMessage).getOrElse("").toLowerCase
val duplicateFeatureNames = message.contains("appears more than one time")
val datasetCreateFailure = message.contains("dataset create")
if (duplicateFeatureNames && datasetCreateFailure) {
true
} else {
hasRetryableMessage(current.getCause)
}
}
}
hasRetryableMessage(error)
}
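// Illustrative sketch, not part of this diff (exception messages are hypothetical):
// the @tailrec walk above matches anywhere in the cause chain, so a wrapped
// native failure still triggers the retry:
//   val cause = new RuntimeException(
//     "Dataset create from samples failed: Feature (Column_5) appears more than one time")
//   shouldRetryWithBulk(new RuntimeException("Job aborted", cause))  // returns true
// Note that both substrings must appear in the same message for a match.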

@@ -390,10 +449,27 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel] with LightGBMModelParams]
/**
* creates a driver thread, and runs mapPartitions on the dataset.
*
* @param dataset The dataset to train on.
* @param batchIndex If running in batch training mode, the batch number.
* @param batchCount The total number of batches.
* @return The LightGBM Model from the trained LightGBM Booster.
*/
private def trainOneDataBatch(dataset: Dataset[_], batchIndex: Int, batchCount: Int): TrainedModel = {
try {
trainOneDataBatchInternal(dataset, batchIndex, batchCount)
} catch {
case ex if getDataTransferMode == LightGBMConstants.StreamingDataTransferMode && shouldRetryWithBulk(ex) =>
log.warn("Detected duplicate feature names while creating LightGBM dataset in streaming mode. " +
"Retrying this batch with dataTransferMode=bulk.")
val originalMode = getDataTransferMode
setDataTransferMode(LightGBMConstants.BulkDataTransferMode)
try {
trainOneDataBatchInternal(dataset, batchIndex, batchCount)
} finally {
setDataTransferMode(originalMode)
}
}
}
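// Design note: dataTransferMode is a mutable estimator Param, so the original
// value is captured before the fallback and restored in the finally block above;
// otherwise a retried batch would leave the estimator in bulk mode for later
// batches and for any subsequent fit() on the same instance.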

private def trainOneDataBatchInternal(dataset: Dataset[_], batchIndex: Int, batchCount: Int): TrainedModel = {
val measures = new InstrumentationMeasures()
setBatchPerformanceMeasure(batchIndex, measures)

@@ -422,7 +498,7 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel] with LightGBMModelParams]
val (serializedReferenceDataset: Option[Array[Byte]], partitionCounts: Option[Array[Long]]) =
if (isStreamingMode) {
val (referenceDataset, partitionCounts) =
calculateRowStatistics(trainingData, trainParams, numCols, measures)
calculateRowStatistics(trainingData, trainParams, numCols, featuresSchema, measures)

// Save the reference Dataset so it's available to client and other batches
if (getReferenceDataset.isEmpty) {
@@ -503,12 +579,14 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel] with LightGBMModelParams]
* @param dataframe The dataset to train on.
* @param trainingParams The training parameters.
* @param numCols The number of feature columns.
* @param featuresSchema The schema of the features column.
* @param measures Instrumentation measures.
* @return The serialized Dataset reference and an array of partition counts.
*/
private def calculateRowStatistics(dataframe: DataFrame,
trainingParams: BaseTrainParams,
numCols: Int,
featuresSchema: StructField,
measures: InstrumentationMeasures): (Array[Byte], Array[Long]) = {
measures.markRowStatisticsStart()

@@ -523,6 +601,9 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel] with LightGBMModelParams]
trainingParams.generalParams.categoricalFeatures,
trainingParams.executionParams.numThreads)

// Get feature names to set on the reference dataset (ensures unique names for Spark 3.5+)
val featureNames = getSlotNamesWithMetadata(featuresSchema)

// Either get a reference dataset (as bytes) from params, or calculate it
val precalculatedDataset = getReferenceDataset
val serializedReference = if (precalculatedDataset.nonEmpty) {
Expand All @@ -541,6 +622,7 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel] with LightGBMModelParams]
totalNumRows,
numCols,
collectedSampleData,
featureNames,
measures,
log)
}
@@ -10,65 +10,80 @@ import org.apache.spark.sql._
import org.slf4j.Logger


// scalastyle:off method.length
object ReferenceDatasetUtils {
def createReferenceDatasetFromSample(datasetParams: String,
featuresCol: String,
numRows: Long,
numCols: Int,
sampledRowData: Array[Row],
featureNames: Option[Array[String]],
measures: InstrumentationMeasures,
log: Logger): Array[Byte] = {
log.info(s"Creating reference training dataset with ${sampledRowData.length} samples and config: $datasetParams")

// Pre-create allocated native pointers so it's easy to clean them up
val datasetVoidPtr = lightgbmlib.voidpp_handle()
val lenPtr = lightgbmlib.new_intp()
val bufferHandlePtr = lightgbmlib.voidpp_handle()

val sampledData = SampledData(sampledRowData.length, numCols)

try {
// create properly formatted sampled data
measures.markSamplingStart()
sampledRowData.zipWithIndex.foreach({case (row, index) => sampledData.pushRow(row, index, featuresCol)})
measures.markSamplingStop()

// Create dataset from samples
// 1. Generate the dataset for features
val datasetVoidPtr = lightgbmlib.voidpp_handle()
LightGBMUtils.validate(lightgbmlib.LGBM_DatasetCreateFromSampledColumn(
sampledData.getSampleData,
sampledData.getSampleIndices,
numCols,
sampledData.getRowCounts,
sampledData.numRows,
1, // Used for allocation and must be > 0, but we don't use this reference set for data collection
numRows,
datasetParams,
datasetVoidPtr), "Dataset create from samples")


// 2. Serialize the raw dataset to a native buffer
val datasetHandle: SWIGTYPE_p_void = lightgbmlib.voidpp_value(datasetVoidPtr)
LightGBMUtils.validate(lightgbmlib.LGBM_DatasetSerializeReferenceToBinary(
datasetHandle,
bufferHandlePtr,
lenPtr), "Serialize ref")
val bufferLen: Int = lightgbmlib.intp_value(lenPtr)
log.info(s"Created serialized reference dataset of length $bufferLen")

// The dataset is now serialized to a buffer, so we don't need original
LightGBMUtils.validate(lightgbmlib.LGBM_DatasetFree(datasetHandle), "Free Dataset")

// This will also free the buffer
toByteArray(bufferHandlePtr, bufferLen)
}
finally {
val datasetHandle = createDatasetFromSamples(sampledData, numCols, numRows, datasetParams)
setFeatureNamesIfProvided(datasetHandle, featureNames, numCols, log)
serializeAndCleanup(datasetHandle, bufferHandlePtr, lenPtr, log)
} finally {
sampledData.delete()
lightgbmlib.delete_voidpp(datasetVoidPtr)
lightgbmlib.delete_voidpp(bufferHandlePtr)
lightgbmlib.delete_intp(lenPtr)
}
}
// scalastyle:on method.length

private def createDatasetFromSamples(sampledData: SampledData,
numCols: Int,
numRows: Long,
datasetParams: String): SWIGTYPE_p_void = {
val datasetVoidPtr = lightgbmlib.voidpp_handle()
LightGBMUtils.validate(lightgbmlib.LGBM_DatasetCreateFromSampledColumn(
sampledData.getSampleData,
sampledData.getSampleIndices,
numCols,
sampledData.getRowCounts,
sampledData.numRows,
1, // Used for allocation and must be > 0, but we don't use this reference set for data collection
numRows,
datasetParams,
datasetVoidPtr), "Dataset create from samples")
lightgbmlib.voidpp_value(datasetVoidPtr)
}
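// Design note: this reference dataset exists to derive consistent feature
// binning from the sample; as the inline comment notes, the allocation hint of 1
// is the minimum valid value because no row data is collected through this handle.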

private def setFeatureNamesIfProvided(datasetHandle: SWIGTYPE_p_void,
featureNames: Option[Array[String]],
numCols: Int,
log: Logger): Unit = {
featureNames.foreach { names =>
if (names.nonEmpty) {
log.info(s"Setting ${names.length} feature names on reference dataset")
LightGBMUtils.validate(lightgbmlib.LGBM_DatasetSetFeatureNames(datasetHandle, names, numCols),
"Dataset set feature names")
}
}
}
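// Illustrative call, not part of this diff (names are hypothetical): callers pass
// the deduplicated Spark-side slot names through createReferenceDatasetFromSample:
//   setFeatureNamesIfProvided(handle, Some(Array("age", "income", "age_1")), numCols = 3, log)
// A None or empty array leaves LightGBM's default Column_N feature names in place.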

private def serializeAndCleanup(datasetHandle: SWIGTYPE_p_void,
bufferHandlePtr: SWIGTYPE_p_p_void,
lenPtr: SWIGTYPE_p_int,
log: Logger): Array[Byte] = {
LightGBMUtils.validate(lightgbmlib.LGBM_DatasetSerializeReferenceToBinary(
datasetHandle, bufferHandlePtr, lenPtr), "Serialize ref")
val bufferLen: Int = lightgbmlib.intp_value(lenPtr)
log.info(s"Created serialized reference dataset of length $bufferLen")
LightGBMUtils.validate(lightgbmlib.LGBM_DatasetFree(datasetHandle), "Free Dataset")
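// toByteArray below also frees the native buffer, so only the dataset handle needs an explicit free here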
toByteArray(bufferHandlePtr, bufferLen)
}

def getInitializedReferenceDataset(ctx: PartitionTaskContext): LightGBMDataset = {
// The definition is broadcast from Spark, so retrieve it