From 835bb3e00a0b45a94630df9d4732660bad765c64 Mon Sep 17 00:00:00 2001
From: Weichen Xu
Date: Thu, 11 Jul 2024 10:14:54 +0800
Subject: [PATCH] update

Signed-off-by: Weichen Xu
---
 .../spark/ml/feature/StringIndexer.scala      | 40 +++++++++----------
 .../spark/ml/feature/StringIndexerSuite.scala |  7 ++++
 2 files changed, 27 insertions(+), 20 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 9e33c38cdaf4f..63a3a67e90cba 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -122,27 +122,9 @@ private[feature] trait StringIndexerBase extends Params with HasHandleInvalid wi
     require(outputColNames.distinct.length == outputColNames.length,
       s"Output columns should not be duplicate.")
 
-    def extractInputDataType(inputColName: String): Option[DataType] = {
-      val inputSplits = inputColName.split("\\.")
-      var dtype: Option[DataType] = Some(schema)
-      var i = 0
-      while (i < inputSplits.length && dtype.isDefined) {
-        val s = inputSplits(i)
-        dtype = if (dtype.get.isInstanceOf[StructType]) {
-          val struct = dtype.get.asInstanceOf[StructType]
-          if (struct.fieldNames.contains(s)) {
-            Some(struct(s).dataType)
-          } else None
-        } else None
-        i += 1
-      }
-
-      dtype
-    }
-
     val outputFields = inputColNames.zip(outputColNames).flatMap {
       case (inputColName, outputColName) =>
-        extractInputDataType(inputColName) match {
+        extractInputDataType(schema, inputColName) match {
           case Some(dtype) => Some(
             validateAndTransformField(schema, inputColName, dtype, outputColName)
           )
@@ -152,6 +134,24 @@ private[feature] trait StringIndexerBase extends Params with HasHandleInvalid wi
     }
     StructType(schema.fields ++ outputFields)
   }
+
+  protected def extractInputDataType(schema: StructType, inputColName: String): Option[DataType] = {
+    val inputSplits = inputColName.split("\\.")
+    var dtype: Option[DataType] = Some(schema)
+    var i = 0
+    while (i < inputSplits.length && dtype.isDefined) {
+      val s = inputSplits(i)
+      dtype = if (dtype.get.isInstanceOf[StructType]) {
+        val struct = dtype.get.asInstanceOf[StructType]
+        if (struct.fieldNames.contains(s)) {
+          Some(struct(s).dataType)
+        } else None
+      } else None
+      i += 1
+    }
+
+    dtype
+  }
 }
 
 /**
@@ -451,7 +451,7 @@ class StringIndexerModel (
       val labelToIndex = labelsToIndexArray(i)
       val labels = labelsArray(i)
 
-      if (!dataset.schema.fieldNames.contains(inputColName)) {
+      if (extractInputDataType(dataset.schema, inputColName).isEmpty) {
         logWarning(log"Input column ${MDC(LogKeys.COLUMN_NAME, inputColName)} does not exist " +
           log"during transformation. Skip StringIndexerModel for this column.")
         outputColNames(i) = null
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index b913d88e80a79..fc3d2d349ab06 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -149,6 +149,13 @@ class StringIndexerSuite extends MLTest with DefaultReadWriteTest {
       (4, 0.0),
       (5, 1.0)
     ).toDF("id", "labelIndex")
+
+    testTransformerByGlobalCheckFunc[(Int, String)](df, indexerModel, "id", "labelIndex") { rows =>
+      val attr = Attribute.fromStructField(rows.head.schema("labelIndex"))
+        .asInstanceOf[NominalAttribute]
+      assert(attr.values.get === Array("a", "c", "b"))
+      assert(rows === expected.collect().toSeq)
+    }
   }
 
   test("StringIndexerUnseen") {