Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: Distinct Aggregation #1232

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,72 +21,37 @@ class DistinctAggregate(val metricCol: String) extends AggregationSpec {

override def isCalculateAggregateNeeded: Boolean = true


/**
 * Folds the accumulated set of distinct values into the final feature value.
 *
 * @param aggregate the Set built up by agg() across the window (null when no records were seen)
 * @param dataType  the declared Spark element type of the metric column
 * @return a comma-separated string of the distinct values, or null when aggregate is null
 * @throws RuntimeException when the declared element type is unsupported
 */
override def calculateAggregate(aggregate: Any, dataType: DataType): Any = {
  if (aggregate == null) {
    aggregate
  } else {
    // The Set carries no element type at runtime (erasure), so dispatch on the declared
    // Spark DataType to document intent and reject unsupported columns explicitly.
    val result = dataType match {
      case IntegerType => aggregate.asInstanceOf[Set[Int]]
      case LongType => aggregate.asInstanceOf[Set[Long]]
      case DoubleType => aggregate.asInstanceOf[Set[Double]]
      case FloatType => aggregate.asInstanceOf[Set[Float]]
      case StringType => aggregate.asInstanceOf[Set[String]]
      // Say DISTINCT (not COUNT_DISTINCT): this class implements the DISTINCT aggregation,
      // and agg() below reports the same name for the same failure.
      case _ => throw new RuntimeException(s"Invalid data type for DISTINCT metric col $metricCol. " +
        s"Only Int, Long, Double, Float, and String are supported, but got ${dataType.typeName}")
    }
    result.mkString(",")
  }
}

/*
 * Record is the raw column value handed to us by SlidingWindowFeatureUtils; aggregate is
 * the running Set of distinct values returned from previous calls. Feathr requires the
 * record type and the aggregate element type to match, so both paths dispatch on the same
 * declared DataType.
 */
override def agg(aggregate: Any, record: Any, dataType: DataType): Any = {
  if (record == null) {
    // Ignore null records up front so a null can never seed the set: Set(null) would
    // otherwise leak a literal "null" into calculateAggregate's joined output.
    aggregate
  } else if (aggregate == null) {
    // First record for this key/window: seed a typed Set, validating the declared
    // element type immediately instead of deferring the failure to a later call.
    dataType match {
      case IntegerType => Set(record.asInstanceOf[Int])
      case LongType => Set(record.asInstanceOf[Long])
      case DoubleType => Set(record.asInstanceOf[Double])
      case FloatType => Set(record.asInstanceOf[Float])
      case StringType => Set(record.asInstanceOf[String])
      case _ => throw new RuntimeException(s"Invalid data type for DISTINCT metric col $metricCol. " +
        s"Only Int, Long, Double, Float, and String are supported, but got ${dataType.typeName}")
    }
  } else {
    // Subsequent records: Set's + is a no-op for duplicates, which is exactly DISTINCT.
    dataType match {
      case IntegerType => aggregate.asInstanceOf[Set[Int]] + record.asInstanceOf[Int]
      case LongType => aggregate.asInstanceOf[Set[Long]] + record.asInstanceOf[Long]
      case DoubleType => aggregate.asInstanceOf[Set[Double]] + record.asInstanceOf[Double]
      case FloatType => aggregate.asInstanceOf[Set[Float]] + record.asInstanceOf[Float]
      case StringType => aggregate.asInstanceOf[Set[String]] + record.asInstanceOf[String]
      case _ => throw new RuntimeException(s"Invalid data type for DISTINCT metric col $metricCol. " +
        s"Only Int, Long, Double, Float, and String are supported, but got ${dataType.typeName}")
    }
  }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1888,180 +1888,5 @@ class SlidingWindowAggIntegTest extends FeathrIntegTest {

validateRows(dfs.select(keyField, features: _*).collect().sortBy(row => row.getAs[Int](keyField)), expectedRows)
}

// Integration test: DISTINCT aggregation over a numeric column produces the set of
// distinct values per key within the sliding window.
@Test
def testSWADistinctIntegers(): Unit = {
// Feature definition: a single DISTINCT feature `g` over the `Id` column, keyed by `x`,
// within a 10-day window.
val featureDefAsString =
"""
|sources: {
| swaSource: {
| location: { path: "generation/daily/" }
| isTimeSeries: true
| timeWindowParameters: {
| timestampColumn: "timestamp"
| timestampColumnFormat: "yyyy-MM-dd"
| }
| }
|}
|anchors: {
| swaAnchorWithKeyExtractor: {
| source: "swaSource"
| key: [x]
| features: {
| g: {
| def: "Id" // the column that contains the raw view count
| aggregation: DISTINCT
| window: 10d
| }
| }
| }
|}
""".stripMargin

val features = Seq("g")
val keyField = "x"
val featureJoinAsString =
s"""
| settings: {
| joinTimeSettings: {
| timestampColumn: {
| def: timestamp
| format: yyyy-MM-dd
| }
| }
|}
|features: [
| {
| key: [$keyField],
| featureList: [${features.mkString(",")}]
| }
|]
""".stripMargin


/**
* Expected output (g = the distinct Ids per key, surfaced as a float array):
* +---+--------------+
* |  x|             g|
* +---+--------------+
* |  1|  [10.0, 11.0]|
* |  2|  [10.0, 11.0]|
* |  3|         [9.0]|
* +---+--------------+
*/
val expectedSchema = StructType(
Seq(
StructField(keyField, LongType),
StructField(features.last, ArrayType(FloatType, false))
))
import scala.collection.mutable.WrappedArray
val expectedRows = Array(
new GenericRowWithSchema(Array(1, Array(10.0f, 11.0f)), expectedSchema),
new GenericRowWithSchema(Array(2, Array(10.0f, 11.0f)), expectedSchema),
new GenericRowWithSchema(Array(3, Array(9.0f)), expectedSchema)
)

val dfs = runLocalFeatureJoinForTest(featureJoinAsString, featureDefAsString, "featuresWithFilterObs.avro.json").data
// Spark returns array columns as WrappedArray; convert to Array so row comparison is
// element-wise. NOTE(review): here the feature column is read directly as an array —
// confirm this matches the join output schema for DISTINCT features.
val result = dfs.select(keyField, features: _*).collect().sortBy(row => row.getAs[Int](keyField))
val actualRows = result.map(row => Row(row.get(0), row.get(1).asInstanceOf[WrappedArray[Float]].toArray))
validateComplexRows(actualRows, expectedRows)

}

// Integration test: DISTINCT aggregation over a string column produces the set of
// distinct values per key within the sliding window.
@Test
def testSWADistinctStrings(): Unit = {
  // Feature definition: a single DISTINCT feature `f` over the string column `user`,
  // keyed by `x`, within a 10-day window.
  val featureDefAsString =
    """
|sources: {
| swaSource: {
| location: { path: "generation/daily/" }
| isTimeSeries: true
| timeWindowParameters: {
| timestampColumn: "timestamp"
| timestampColumnFormat: "yyyy-MM-dd"
| }
| }
|}
|anchors: {
| swaAnchorWithKeyExtractor: {
| source: "swaSource"
| key: [x]
| features: {
| f: {
| def: "user" // the column that contains the user as string
| aggregation: DISTINCT
| window: 10d
| }
| }
| }
|}
""".stripMargin

  val features = Seq("f")
  val keyField = "x"
  val featureJoinAsString =
    s"""
| settings: {
| joinTimeSettings: {
| timestampColumn: {
| def: timestamp
| format: yyyy-MM-dd
| }
| }
|}
|features: [
| {
| key: [$keyField],
| featureList: [${features.mkString(",")}]
| }
|]
""".stripMargin


  /**
   * Expected output (f = the distinct users per key):
   * +---+------------------+
   * |  x|                 f|
   * +---+------------------+
   * |  1|  [user10, user11]|
   * |  2|  [user10, user11]|
   * |  3|           [user9]|
   * +---+------------------+
   */
  val expectedSchema = StructType(
    Seq(
      StructField(keyField, LongType),
      StructField(features.last, ArrayType(StringType, false))
    ))
  import scala.collection.mutable.WrappedArray
  val expectedRows = Array(
    new GenericRowWithSchema(Array(1, Array("user10", "user11")), expectedSchema),
    new GenericRowWithSchema(Array(2, Array("user10", "user11")), expectedSchema),
    new GenericRowWithSchema(Array(3, Array("user9")), expectedSchema)
  )
  val dfs = runLocalFeatureJoinForTest(featureJoinAsString, featureDefAsString, "featuresWithFilterObs.avro.json").data
  val result = dfs.select(keyField, features: _*).collect().sortBy(row => row.getAs[Int](keyField))

  // The feature column comes back as a struct wrapping the string array; unwrap it and
  // pair it with the 1-based key so rows line up with expectedRows. (Replaces the former
  // mutable index loop, the dead `getClass` statement, and the leftover debug print.)
  val actualRows = result.zipWithIndex.map { case (row, idx) =>
    val featureStruct = row.get(1).asInstanceOf[GenericRowWithSchema]
    Row(idx + 1, featureStruct.get(0).asInstanceOf[WrappedArray[String]].toArray)
  }

  validateComplexRows(actualRows, expectedRows)
}

}

1 change: 1 addition & 0 deletions feathr_project/feathr/datasets/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Dataset utilities
"""

import logging
import math
from pathlib import Path
Expand Down
Loading
Loading