Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
Signed-off-by: Weichen Xu <[email protected]>
  • Loading branch information
WeichenXu123 committed Jul 11, 2024
1 parent b29493b commit 835bb3e
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -122,27 +122,9 @@ private[feature] trait StringIndexerBase extends Params with HasHandleInvalid wi
require(outputColNames.distinct.length == outputColNames.length,
s"Output columns should not be duplicate.")

/**
 * Resolves the data type of a possibly nested input column by walking the enclosing
 * `schema` along the dot-separated segments of `inputColName`.
 *
 * Every segment must name a field of the struct reached so far; names are compared
 * exactly as written (via `fieldNames.contains`).
 *
 * @param inputColName column name; nested fields are addressed as `outer.inner`
 * @return the data type of the addressed field, or `None` when a segment is not a field
 *         of the current struct, or the current type is not a struct
 */
def extractInputDataType(inputColName: String): Option[DataType] = {
  // A negative limit keeps trailing empty segments. With the default limit, "a." would be
  // read as "a", and "." would split into zero segments and resolve to the root schema.
  val pathSegments = inputColName.split("\\.", -1)
  val root: Option[DataType] = Some(schema)

  pathSegments.foldLeft(root) {
    case (Some(struct: StructType), fieldName) if struct.fieldNames.contains(fieldName) =>
      Some(struct(fieldName).dataType)
    // Missing field, non-struct parent, or an earlier segment already failed to resolve.
    case _ => None
  }
}

val outputFields = inputColNames.zip(outputColNames).flatMap {
case (inputColName, outputColName) =>
extractInputDataType(inputColName) match {
extractInputDataType(schema, inputColName) match {
case Some(dtype) => Some(
validateAndTransformField(schema, inputColName, dtype, outputColName)
)
Expand All @@ -152,6 +134,24 @@ private[feature] trait StringIndexerBase extends Params with HasHandleInvalid wi
}
StructType(schema.fields ++ outputFields)
}

/**
 * Resolves the data type of a possibly nested column by walking `schema` along the
 * dot-separated segments of `inputColName`.
 *
 * Every segment must name a field of the struct reached so far; names are compared
 * exactly as written (via `fieldNames.contains`).
 *
 * @param schema schema in which the lookup starts
 * @param inputColName column name; nested fields are addressed as `outer.inner`
 * @return the data type of the addressed field, or `None` when a segment is not a field
 *         of the current struct, or the current type is not a struct
 */
protected def extractInputDataType(schema: StructType, inputColName: String): Option[DataType] = {
  // A negative limit keeps trailing empty segments. With the default limit, "a." would be
  // read as "a", and "." would split into zero segments and resolve to the root schema.
  val pathSegments = inputColName.split("\\.", -1)
  val root: Option[DataType] = Some(schema)

  pathSegments.foldLeft(root) {
    case (Some(struct: StructType), fieldName) if struct.fieldNames.contains(fieldName) =>
      Some(struct(fieldName).dataType)
    // Missing field, non-struct parent, or an earlier segment already failed to resolve.
    case _ => None
  }
}
}

/**
Expand Down Expand Up @@ -451,7 +451,7 @@ class StringIndexerModel (
val labelToIndex = labelsToIndexArray(i)
val labels = labelsArray(i)

if (!dataset.schema.fieldNames.contains(inputColName)) {
if (extractInputDataType(dataset.schema, inputColName).isEmpty) {
logWarning(log"Input column ${MDC(LogKeys.COLUMN_NAME, inputColName)} does not exist " +
log"during transformation. Skip StringIndexerModel for this column.")
outputColNames(i) = null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,13 @@ class StringIndexerSuite extends MLTest with DefaultReadWriteTest {
(4, 0.0),
(5, 1.0)
).toDF("id", "labelIndex")

testTransformerByGlobalCheckFunc[(Int, String)](df, indexerModel, "id", "labelIndex") { rows =>
val attr = Attribute.fromStructField(rows.head.schema("labelIndex"))
.asInstanceOf[NominalAttribute]
assert(attr.values.get === Array("a", "c", "b"))
assert(rows === expected.collect().toSeq)
}
}

test("StringIndexerUnseen") {
Expand Down

0 comments on commit 835bb3e

Please sign in to comment.