
Commit ab3a6a4

init
Signed-off-by: Weichen Xu <[email protected]>
WeichenXu123 committed Nov 4, 2024
1 parent c53dac0 commit ab3a6a4
Showing 5 changed files with 10 additions and 5 deletions.
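
All five files follow the same pattern: direct StructType lookups such as schema($(inputCol)).dataType are replaced with the SchemaUtils.getSchemaField / SchemaUtils.getSchemaFieldType helpers. The implementation of those helpers is not part of this diff; the sketch below is only an assumption of what they might look like if their purpose is to resolve dot-separated nested column names (for example "meta.text") in addition to top-level fields. The object name NestedFieldLookup and the error messages are invented for illustration.

// Hypothetical sketch only: the real SchemaUtils.getSchemaField /
// getSchemaFieldType implementations are not shown in this commit.
import org.apache.spark.sql.types.{DataType, StructField, StructType}

object NestedFieldLookup {

  // Resolve a column name to its StructField, walking into nested structs
  // when the name is dot-separated (e.g. "meta.text"). This simple split
  // does not handle backtick-quoted names that themselves contain dots.
  def getSchemaField(schema: StructType, colName: String): StructField = {
    val parts = colName.split('.')
    var field: StructField = schema(parts.head) // throws if the field is missing
    parts.tail.foreach { part =>
      field.dataType match {
        case st: StructType => field = st(part)
        case other =>
          throw new IllegalArgumentException(
            s"Field ${field.name} has type ${other.catalogString}; cannot resolve nested name $part.")
      }
    }
    field
  }

  // Convenience wrapper that returns only the resolved field's data type.
  def getSchemaFieldType(schema: StructType, colName: String): DataType =
    getSchemaField(schema, colName).dataType
}

Where a transformer needs attribute metadata rather than just a data type (Binarizer, Normalizer below), the diff calls getSchemaField and passes the resolved StructField to AttributeGroup.fromStructField, which reads the ML attribute metadata attached to that field.
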
3 changes: 2 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -24,6 +24,7 @@ import org.apache.spark.annotation.Since
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util.SchemaUtils
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
@@ -103,7 +104,7 @@ abstract class UnaryTransformer[IN: TypeTag, OUT: TypeTag, T <: UnaryTransformer
   protected def validateInputType(inputType: DataType): Unit = {}

   override def transformSchema(schema: StructType): StructType = {
-    val inputType = schema($(inputCol)).dataType
+    val inputType = SchemaUtils.getSchemaFieldType(schema, $(inputCol))
     validateInputType(inputType)
     if (schema.fieldNames.contains($(outputCol))) {
       throw new IllegalArgumentException(s"Output column ${$(outputCol)} already exists.")
4 changes: 3 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
@@ -214,7 +214,9 @@ final class Binarizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
         case DoubleType =>
           BinaryAttribute.defaultAttr.withName(outputColName).toStructField()
         case _: VectorUDT =>
-          val size = AttributeGroup.fromStructField(schema(inputColName)).size
+          val size = AttributeGroup.fromStructField(
+            SchemaUtils.getSchemaField(schema, inputColName)
+          ).size
           if (size < 0) {
             StructField(outputColName, new VectorUDT)
           } else {
2 changes: 1 addition & 1 deletion mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -105,7 +105,7 @@ class HashingTF @Since("3.0.0") private[ml] (

@Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
val inputType = schema($(inputCol)).dataType
val inputType = SchemaUtils.getSchemaFieldType(schema, $(inputCol))
require(inputType.isInstanceOf[ArrayType],
s"The input column must be ${ArrayType.simpleString}, but got ${inputType.catalogString}.")
val attrGroup = new AttributeGroup($(outputCol), $(numFeatures))
4 changes: 3 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
@@ -71,7 +71,9 @@ class Normalizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
   override def transformSchema(schema: StructType): StructType = {
     var outputSchema = super.transformSchema(schema)
     if ($(inputCol).nonEmpty && $(outputCol).nonEmpty) {
-      val size = AttributeGroup.fromStructField(schema($(inputCol))).size
+      val size = AttributeGroup.fromStructField(
+        SchemaUtils.getSchemaField(schema, $(inputCol))
+      ).size
       if (size >= 0) {
         outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
           $(outputCol), size)
2 changes: 1 addition & 1 deletion mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
@@ -41,7 +41,7 @@ private[spark] object SchemaUtils {
       colName: String,
       dataType: DataType,
       msg: String = ""): Unit = {
-    val actualDataType = schema(colName).dataType
+    val actualDataType = SchemaUtils.getSchemaField(schema, colName).dataType
     val message = if (msg != null && msg.trim.length > 0) " " + msg else ""
     require(actualDataType.equals(dataType),
       s"Column $colName must be of type ${dataType.getClass}:${dataType.catalogString} " +
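
SchemaUtils is a private[spark] utility (see the object declaration above), so the snippet below only illustrates the internal check this change presumably enables for nested column names. The schema layout and column name are invented for the example, and the call itself is shown commented out because the object is not accessible outside Spark.

import org.apache.spark.sql.types._

// A schema with a nested struct column.
val schema = StructType(Seq(
  StructField("meta", StructType(Seq(
    StructField("text", ArrayType(StringType))
  )))
))

// Before this change the check resolved the column with schema("meta.text"),
// which only matches top-level field names and would throw for a nested name.
// With the helper-based lookup the same check can presumably resolve it:
// SchemaUtils.checkColumnType(schema, "meta.text", ArrayType(StringType))
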
