From 9e98ae3355a839b253db5248762212a296c83ac2 Mon Sep 17 00:00:00 2001
From: Andrew Tang <andrew.tang@warnermedia.com>
Date: Thu, 4 Apr 2024 12:12:23 -0700
Subject: [PATCH 1/2] fix DISTINCT aggregation bug

The DISTINCT aggregate previously accepted only array-typed metric columns,
and its first-record path unconditionally cast the record to
WrappedArray[Int] regardless of the declared element type. Rework it to
accumulate scalar records into a Set per window and to render the final
aggregate as a comma-separated string in calculateAggregate. The two
sliding-window integration tests that asserted the old array-valued output
are removed.

---
 .../swj/aggregate/DistinctAggregate.scala     |  69 ++-----
 .../offline/SlidingWindowAggIntegTest.scala   | 175 ------------------
 2 files changed, 17 insertions(+), 227 deletions(-)
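
Notes (below the --- marker, so not applied by git am): a minimal sketch of
the reworked DISTINCT accumulation, assuming only the agg() and
calculateAggregate() contract shown in the hunk below. The driver object and
the "viewCount" column name are hypothetical, not part of the patch; the
sliding-window join calls these methods once per record and once per window
in the same way.

    import org.apache.spark.sql.types.IntegerType
    import com.linkedin.feathr.swj.aggregate.DistinctAggregate

    object DistinctSketch extends App {
      val spec = new DistinctAggregate("viewCount")

      // agg() seeds a Set from the first record, then adds each later one;
      // duplicate values collapse automatically.
      val acc1 = spec.agg(null, 10, IntegerType)   // Set(10)
      val acc2 = spec.agg(acc1, 11, IntegerType)   // Set(10, 11)
      val acc3 = spec.agg(acc2, 10, IntegerType)   // still Set(10, 11)

      // calculateAggregate() renders the final set as a comma-joined string
      // (element order is whatever the Set iterator yields).
      println(spec.calculateAggregate(acc3, IntegerType))  // "10,11"
    }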

diff --git a/feathr-impl/src/main/scala/com/linkedin/feathr/swj/aggregate/DistinctAggregate.scala b/feathr-impl/src/main/scala/com/linkedin/feathr/swj/aggregate/DistinctAggregate.scala
index 2d6bb53a6..57e77ef2a 100644
--- a/feathr-impl/src/main/scala/com/linkedin/feathr/swj/aggregate/DistinctAggregate.scala
+++ b/feathr-impl/src/main/scala/com/linkedin/feathr/swj/aggregate/DistinctAggregate.scala
@@ -21,72 +21,37 @@ class DistinctAggregate(val metricCol: String) extends AggregationSpec {
 
   override def isCalculateAggregateNeeded: Boolean = true
 
-
   override def calculateAggregate(aggregate: Any, dataType: DataType): Any = {
     if (aggregate == null) {
       aggregate
     } else {
-      dataType match {
-        case ArrayType(IntegerType, false) => aggregate
-        case ArrayType(LongType, false) => aggregate
-        case ArrayType(DoubleType, false) => aggregate
-        case ArrayType(FloatType, false) => aggregate
-        case ArrayType(StringType, false) => aggregate
-        case _ => throw new RuntimeException(s"Invalid data type for DISTINCT metric col $metricCol. " +
-          s"Only Array[Int], Array[Long], Array[Double], Array[Float] and Array[String] are supported, but got ${dataType.typeName}")
+      val result = dataType match {
+        case IntegerType => aggregate.asInstanceOf[Set[Int]]
+        case LongType => aggregate.asInstanceOf[Set[Long]]
+        case DoubleType => aggregate.asInstanceOf[Set[Double]]
+        case FloatType => aggregate.asInstanceOf[Set[Float]]
+        case StringType => aggregate.asInstanceOf[Set[String]]
+        case _ => throw new RuntimeException(s"Invalid data type for DISTINCT metric col $metricCol. " +
+          s"Only Int, Long, Double, Float, and String are supported, but got ${dataType.typeName}")
       }
+      result.mkString(",")
     }
   }
 
-  /*
-    Record is what we get from SlidingWindowFeatureUtils. Aggregate is what we return here for first time.
-    The datatype of both should match. This is a limitation of Feathr
-   */
-   override def agg(aggregate: Any, record: Any, dataType: DataType): Any = {
+  override def agg(aggregate: Any, record: Any, dataType: DataType): Any = {
     if (aggregate == null) {
-      val wrappedArray = record.asInstanceOf[mutable.WrappedArray[Int]]
-      return ArrayBuffer(wrappedArray: _*)
+      Set(record) // seed an untyped Set on the first record; the per-type casts below succeed at runtime via erasure
     } else if (record == null) {
       aggregate
     } else {
       dataType match {
-        case ArrayType(IntegerType, false) =>
-          val set1 = aggregate.asInstanceOf[mutable.ArrayBuffer[Int]].toSet
-          val set2 = record.asInstanceOf[mutable.WrappedArray[Int]].toArray.toSet
-          val set3 = set1.union(set2)
-          val new_aggregate = ArrayBuffer(set3.toSeq: _*)
-          return new_aggregate
-
-        case ArrayType(LongType, false) =>
-          val set1 = aggregate.asInstanceOf[mutable.WrappedArray[Long]]toSet
-          val set2 = record.asInstanceOf[mutable.WrappedArray[Long]]toSet
-          val set3 = set1.union(set2)
-          val new_aggregate = ArrayBuffer(set3.toSeq: _*)
-          return new_aggregate
-
-        case ArrayType(DoubleType, false) =>
-          val set1 = aggregate.asInstanceOf[mutable.WrappedArray[Double]].toSet
-          val set2 = record.asInstanceOf[mutable.WrappedArray[Double]].toSet
-          val set3 = set1.union(set2)
-          val new_aggregate = ArrayBuffer(set3.toSeq: _*)
-          return new_aggregate
-
-        case ArrayType(FloatType, false) =>
-          val set1 = aggregate.asInstanceOf[mutable.WrappedArray[Float]].toSet
-          val set2 = record.asInstanceOf[mutable.WrappedArray[Float]].toSet
-          val set3 = set1.union(set2)
-          val new_aggregate = ArrayBuffer(set3.toSeq: _*)
-          return new_aggregate
-
-        case ArrayType(StringType, false) =>
-          val set1 = aggregate.asInstanceOf[mutable.ArrayBuffer[String]].toSet
-          val set2 = record.asInstanceOf[mutable.WrappedArray[String]].toArray.toSet
-          val set3 = set1.union(set2)
-          val new_aggregate = ArrayBuffer(set3.toSeq: _*)
-          return new_aggregate
-
+        case IntegerType => aggregate.asInstanceOf[Set[Int]] + record.asInstanceOf[Int]
+        case LongType => aggregate.asInstanceOf[Set[Long]] + record.asInstanceOf[Long]
+        case DoubleType => aggregate.asInstanceOf[Set[Double]] + record.asInstanceOf[Double]
+        case FloatType => aggregate.asInstanceOf[Set[Float]] + record.asInstanceOf[Float]
+        case StringType => aggregate.asInstanceOf[Set[String]] + record.asInstanceOf[String]
         case _ => throw new RuntimeException(s"Invalid data type for DISTINCT metric col $metricCol. " +
-          s"Only Array[Int], Array[Long], Array[Double], Array[Float] and Array[String] are supported, but got ${dataType.typeName}")
+          s"Only Int, Long, Double, Float, and String are supported, but got ${dataType.typeName}")
       }
     }
   }
diff --git a/feathr-impl/src/test/scala/com/linkedin/feathr/offline/SlidingWindowAggIntegTest.scala b/feathr-impl/src/test/scala/com/linkedin/feathr/offline/SlidingWindowAggIntegTest.scala
index a7f7f57be..99b6c4133 100644
--- a/feathr-impl/src/test/scala/com/linkedin/feathr/offline/SlidingWindowAggIntegTest.scala
+++ b/feathr-impl/src/test/scala/com/linkedin/feathr/offline/SlidingWindowAggIntegTest.scala
@@ -1888,180 +1888,5 @@ class SlidingWindowAggIntegTest extends FeathrIntegTest {
 
     validateRows(dfs.select(keyField, features: _*).collect().sortBy(row => row.getAs[Int](keyField)), expectedRows)
   }
-
-  @Test
-  def testSWADistinctIntegers(): Unit = {
-    val featureDefAsString =
-      """
-        |sources: {
-        |  swaSource: {
-        |    location: { path: "generation/daily/" }
-        |    isTimeSeries: true
-        |    timeWindowParameters: {
-        |      timestampColumn: "timestamp"
-        |      timestampColumnFormat: "yyyy-MM-dd"
-        |    }
-        |  }
-        |}
-        |anchors: {
-        |  swaAnchorWithKeyExtractor: {
-        |    source: "swaSource"
-        |    key: [x]
-        |    features: {
-        |      g: {
-        |        def: "Id"   // the column that contains the raw view count
-        |        aggregation: DISTINCT
-        |        window: 10d
-        |      }
-        |    }
-        |  }
-        |}
-      """.stripMargin
-
-    val features = Seq("g")
-    val keyField = "x"
-    val featureJoinAsString =
-      s"""
-         | settings: {
-         |  joinTimeSettings: {
-         |   timestampColumn: {
-         |     def: timestamp
-         |     format: yyyy-MM-dd
-         |   }
-         |  }
-         |}
-         |features: [
-         |  {
-         |    key: [$keyField],
-         |    featureList: [${features.mkString(",")}]
-         |  }
-         |]
-      """.stripMargin
-
-
-    /**
-     * Expected output:
-     * +--------+----+----+
-     * |x|   f|   g|
-     * +--------+----+----+
-     * |       1|   6|   2|
-     * |       2|   5|   2|
-     * |       3|   1|   1|
-     * +--------+----+----+
-     */
-    val expectedSchema = StructType(
-      Seq(
-        StructField(keyField, LongType),
-        StructField(features.last, ArrayType(FloatType, false))
-      ))
-    import scala.collection.mutable.WrappedArray
-    val expectedRows = Array(
-      new GenericRowWithSchema(Array(1, Array(10.0f, 11.0f)), expectedSchema),
-      new GenericRowWithSchema(Array(2, Array(10.0f, 11.0f)), expectedSchema),
-      new GenericRowWithSchema(Array(3, Array(9.0f)), expectedSchema)
-    )
-
-    val dfs = runLocalFeatureJoinForTest(featureJoinAsString, featureDefAsString, "featuresWithFilterObs.avro.json").data
-    val result = dfs.select(keyField, features: _*).collect().sortBy(row => row.getAs[Int](keyField))
-    val actualRows = result.map(row => Row(row.get(0), row.get(1).asInstanceOf[WrappedArray[Float]].toArray))
-    validateComplexRows(actualRows, expectedRows)
-
-  }
-
-  @Test
-  def testSWADistinctStrings(): Unit = {
-    val featureDefAsString =
-      """
-        |sources: {
-        |  swaSource: {
-        |    location: { path: "generation/daily/" }
-        |    isTimeSeries: true
-        |    timeWindowParameters: {
-        |      timestampColumn: "timestamp"
-        |      timestampColumnFormat: "yyyy-MM-dd"
-        |    }
-        |  }
-        |}
-        |anchors: {
-        |  swaAnchorWithKeyExtractor: {
-        |    source: "swaSource"
-        |    key: [x]
-        |    features: {
-        |     f: {
-        |        def: "user"   // the column that contains the user as string
-        |        aggregation: DISTINCT
-        |        window: 10d
-        |      }
-        |    }
-        |  }
-        |}
-        """.stripMargin
-
-    val features = Seq("f")
-    val keyField = "x"
-    val featureJoinAsString =
-      s"""
-         | settings: {
-         |  joinTimeSettings: {
-         |   timestampColumn: {
-         |     def: timestamp
-         |     format: yyyy-MM-dd
-         |   }
-         |  }
-         |}
-         |features: [
-         |  {
-         |    key: [$keyField],
-         |    featureList: [${features.mkString(",")}]
-         |  }
-         |]
-        """.stripMargin
-
-
-    /**
-     * Expected output:
-     * +--------+----+----+
-     * |x|   f|   g|
-     * +--------+----+----+
-     * |       1|   6|   2|
-     * |       2|   5|   2|
-     * |       3|   1|   1|
-     * +--------+----+----+
-     */
-    val expectedSchema = StructType(
-      Seq(
-        StructField(keyField, LongType),
-        StructField(features.last, ArrayType(StringType, false))
-      ))
-    import scala.collection.mutable.WrappedArray
-    val expectedRows = Array(
-      new GenericRowWithSchema(Array(1, Array("user10", "user11")), expectedSchema),
-      new GenericRowWithSchema(Array(2, Array("user10", "user11")), expectedSchema),
-      new GenericRowWithSchema(Array(3, Array("user9")), expectedSchema)
-    )
-    val dfs = runLocalFeatureJoinForTest(featureJoinAsString, featureDefAsString, "featuresWithFilterObs.avro.json").data
-    val result = dfs.select(keyField, features: _*).collect().sortBy(row => row.getAs[Int](keyField))
-    //val col1 = result.map(row => Row(row.get(0)))
-    val col2_initial = result.map(row => Row(row.get(1)))
-
-    val actualRows = new Array[Row](3);
-    var count = 1
-
-    for (row <- col2_initial) {
-      val genericRow = row.get(0);
-      genericRow.getClass;
-      val genericRow2 = genericRow.asInstanceOf[GenericRowWithSchema];
-      val val1 = genericRow2.get(0);
-      val val2 = val1.asInstanceOf[WrappedArray[String]].toArray
-      print(val1)
-      val resultRow = Row(count, val2)
-      actualRows(count - 1) = resultRow
-      count += 1
-    }
-
-    validateComplexRows(actualRows, expectedRows)
-
-  }
-
 }
 

From be93db89971a08f6ff12d48559dbbd1ed289bdc5 Mon Sep 17 00:00:00 2001
From: Andrew Tang <andrew.tang@warnermedia.com>
Date: Thu, 4 Apr 2024 12:17:36 -0700
Subject: [PATCH 2/2] style: run black

---
 feathr_project/feathr/datasets/utils.py       |  1 +
 .../feathr/protobuf/featureValue_pb2.py       | 28 +++++++++----------
 .../registry/_feature_registry_purview.py     |  8 ++++--
 .../udf/_preprocessing_pyudf_manager.py       |  6 ++--
 feathr_project/test/test_feature_registry.py  |  1 +
 feathr_project/test/test_fixture.py           |  6 ++--
 feathr_project/test/test_registry_client.py   |  3 ++
 .../test/unit/utils/test_job_utils.py         |  4 ++-
 .../test/unit/utils/test_platform.py          |  1 +
 9 files changed, 34 insertions(+), 24 deletions(-)

diff --git a/feathr_project/feathr/datasets/utils.py b/feathr_project/feathr/datasets/utils.py
index 5dcfb6e87..8f975c9e2 100644
--- a/feathr_project/feathr/datasets/utils.py
+++ b/feathr_project/feathr/datasets/utils.py
@@ -1,5 +1,6 @@
 """Dataset utilities
 """
+
 import logging
 import math
 from pathlib import Path
diff --git a/feathr_project/feathr/protobuf/featureValue_pb2.py b/feathr_project/feathr/protobuf/featureValue_pb2.py
index 7b3f24a29..2b3224a44 100644
--- a/feathr_project/feathr/protobuf/featureValue_pb2.py
+++ b/feathr_project/feathr/protobuf/featureValue_pb2.py
@@ -37,7 +37,7 @@
     (_message.Message,),
     {
         "DESCRIPTOR": _FEATUREVALUE,
-        "__module__": "featureValue_pb2"
+        "__module__": "featureValue_pb2",
         # @@protoc_insertion_point(class_scope:protobuf.FeatureValue)
     },
 )
@@ -48,7 +48,7 @@
     (_message.Message,),
     {
         "DESCRIPTOR": _BOOLEANARRAY,
-        "__module__": "featureValue_pb2"
+        "__module__": "featureValue_pb2",
         # @@protoc_insertion_point(class_scope:protobuf.BooleanArray)
     },
 )
@@ -59,7 +59,7 @@
     (_message.Message,),
     {
         "DESCRIPTOR": _STRINGARRAY,
-        "__module__": "featureValue_pb2"
+        "__module__": "featureValue_pb2",
         # @@protoc_insertion_point(class_scope:protobuf.StringArray)
     },
 )
@@ -70,7 +70,7 @@
     (_message.Message,),
     {
         "DESCRIPTOR": _DOUBLEARRAY,
-        "__module__": "featureValue_pb2"
+        "__module__": "featureValue_pb2",
         # @@protoc_insertion_point(class_scope:protobuf.DoubleArray)
     },
 )
@@ -81,7 +81,7 @@
     (_message.Message,),
     {
         "DESCRIPTOR": _FLOATARRAY,
-        "__module__": "featureValue_pb2"
+        "__module__": "featureValue_pb2",
         # @@protoc_insertion_point(class_scope:protobuf.FloatArray)
     },
 )
@@ -92,7 +92,7 @@
     (_message.Message,),
     {
         "DESCRIPTOR": _INTEGERARRAY,
-        "__module__": "featureValue_pb2"
+        "__module__": "featureValue_pb2",
         # @@protoc_insertion_point(class_scope:protobuf.IntegerArray)
     },
 )
@@ -103,7 +103,7 @@
     (_message.Message,),
     {
         "DESCRIPTOR": _LONGARRAY,
-        "__module__": "featureValue_pb2"
+        "__module__": "featureValue_pb2",
         # @@protoc_insertion_point(class_scope:protobuf.LongArray)
     },
 )
@@ -114,7 +114,7 @@
     (_message.Message,),
     {
         "DESCRIPTOR": _BYTESARRAY,
-        "__module__": "featureValue_pb2"
+        "__module__": "featureValue_pb2",
         # @@protoc_insertion_point(class_scope:protobuf.BytesArray)
     },
 )
@@ -125,7 +125,7 @@
     (_message.Message,),
     {
         "DESCRIPTOR": _SPARSESTRINGARRAY,
-        "__module__": "featureValue_pb2"
+        "__module__": "featureValue_pb2",
         # @@protoc_insertion_point(class_scope:protobuf.SparseStringArray)
     },
 )
@@ -136,7 +136,7 @@
     (_message.Message,),
     {
         "DESCRIPTOR": _SPARSEBOOLARRAY,
-        "__module__": "featureValue_pb2"
+        "__module__": "featureValue_pb2",
         # @@protoc_insertion_point(class_scope:protobuf.SparseBoolArray)
     },
 )
@@ -147,7 +147,7 @@
     (_message.Message,),
     {
         "DESCRIPTOR": _SPARSEINTEGERARRAY,
-        "__module__": "featureValue_pb2"
+        "__module__": "featureValue_pb2",
         # @@protoc_insertion_point(class_scope:protobuf.SparseIntegerArray)
     },
 )
@@ -158,7 +158,7 @@
     (_message.Message,),
     {
         "DESCRIPTOR": _SPARSELONGARRAY,
-        "__module__": "featureValue_pb2"
+        "__module__": "featureValue_pb2",
         # @@protoc_insertion_point(class_scope:protobuf.SparseLongArray)
     },
 )
@@ -169,7 +169,7 @@
     (_message.Message,),
     {
         "DESCRIPTOR": _SPARSEDOUBLEARRAY,
-        "__module__": "featureValue_pb2"
+        "__module__": "featureValue_pb2",
         # @@protoc_insertion_point(class_scope:protobuf.SparseDoubleArray)
     },
 )
@@ -180,7 +180,7 @@
     (_message.Message,),
     {
         "DESCRIPTOR": _SPARSEFLOATARRAY,
-        "__module__": "featureValue_pb2"
+        "__module__": "featureValue_pb2",
         # @@protoc_insertion_point(class_scope:protobuf.SparseFloatArray)
     },
 )
diff --git a/feathr_project/feathr/registry/_feature_registry_purview.py b/feathr_project/feathr/registry/_feature_registry_purview.py
index a8778f87e..5f956a3be 100644
--- a/feathr_project/feathr/registry/_feature_registry_purview.py
+++ b/feathr_project/feathr/registry/_feature_registry_purview.py
@@ -1124,9 +1124,11 @@ def _list_registered_entities_with_details(
                 {
                     "and": [
                         {
-                            "or": [{"entityType": TYPEDEF_FEATHR_PROJECT}]
-                            if TYPEDEF_FEATHR_PROJECT in entity_type_list
-                            else None
+                            "or": (
+                                [{"entityType": TYPEDEF_FEATHR_PROJECT}]
+                                if TYPEDEF_FEATHR_PROJECT in entity_type_list
+                                else None
+                            )
                         },
                         {"attributeName": "qualifiedName", "operator": "startswith", "attributeValue": project_name},
                     ]
diff --git a/feathr_project/feathr/udf/_preprocessing_pyudf_manager.py b/feathr_project/feathr/udf/_preprocessing_pyudf_manager.py
index e400995ff..c35f81826 100644
--- a/feathr_project/feathr/udf/_preprocessing_pyudf_manager.py
+++ b/feathr_project/feathr/udf/_preprocessing_pyudf_manager.py
@@ -58,9 +58,9 @@ def build_anchor_preprocessing_metadata(anchor_list: List[FeatureAnchor], local_
                 feature_names.sort()
                 string_feature_list = ",".join(feature_names)
                 if isinstance(anchor.source.preprocessing, str):
-                    feature_names_to_func_mapping[
-                        string_feature_list
-                    ] = _PreprocessingPyudfManager._parse_function_str_for_name(anchor.source.preprocessing)
+                    feature_names_to_func_mapping[string_feature_list] = (
+                        _PreprocessingPyudfManager._parse_function_str_for_name(anchor.source.preprocessing)
+                    )
                 else:
                     # it's a callable function
                     feature_names_to_func_mapping[string_feature_list] = anchor.source.preprocessing.__name__
diff --git a/feathr_project/test/test_feature_registry.py b/feathr_project/test/test_feature_registry.py
index f7944ef30..ccd0f9623 100644
--- a/feathr_project/test/test_feature_registry.py
+++ b/feathr_project/test/test_feature_registry.py
@@ -79,6 +79,7 @@ def test_feathr_register_features_e2e(self):
 
                 # <ExceptionInfo RuntimeError('Failed to call registry API, status is 409, error is {"message":"Entity feathr_ci_registry_53_3_476999__request_features__f_is_long_trip_distance already exists"}')
             assert "status is 409" in str(exc_info.value)
+
     @pytest.mark.skip(reason="Skipping 502 tests due to server side failure")
     def test_feathr_register_features_partially(self):
         """
diff --git a/feathr_project/test/test_fixture.py b/feathr_project/test/test_fixture.py
index 75539f1a7..024ba8597 100644
--- a/feathr_project/test/test_fixture.py
+++ b/feathr_project/test/test_fixture.py
@@ -367,9 +367,9 @@ def registry_test_setup(config_path: str):
     # Use a new project name every time to make sure all features are registered correctly
     # Project name example: feathr_ci_registry_2022_09_24_01_02_30
     now = datetime.now()
-    os.environ[
-        "project_config__project_name"
-    ] = f'feathr_ci_registry_{str(now)[:19].replace(" ", "_").replace(":", "_").replace("-", "_")}'
+    os.environ["project_config__project_name"] = (
+        f'feathr_ci_registry_{str(now)[:19].replace(" ", "_").replace(":", "_").replace("-", "_")}'
+    )
 
     client = FeathrClient(config_path=config_path, project_registry_tag={"for_test_purpose": "true"})
     request_anchor, agg_anchor, derived_feature_list = generate_entities()
diff --git a/feathr_project/test/test_registry_client.py b/feathr_project/test/test_registry_client.py
index 3fcc08aca..d9028baa7 100644
--- a/feathr_project/test/test_registry_client.py
+++ b/feathr_project/test/test_registry_client.py
@@ -220,6 +220,7 @@ def test_parse_project():
     )
     assert len(derived_features) == 3
 
+
 @pytest.mark.skip(reason="Skip since we cannot connect to external database")
 def test_registry_client_list_features():
     c = _FeatureRegistry(project_name="p", endpoint="https://feathr-sql-registry.azurewebsites.net/api/v1")
@@ -229,6 +230,7 @@ def test_registry_client_list_features():
     for i in f:
         assert i.startswith("feathr_ci_registry_getting_started__")
 
+
 @pytest.mark.skip(reason="Skip since we cannot connect to external database")
 def test_registry_client_load():
     c = _FeatureRegistry(project_name="p", endpoint="https://feathr-sql-registry.azurewebsites.net/api/v1")
@@ -247,6 +249,7 @@ def test_registry_client_load():
     )
     assert len(derived_features) == 2
 
+
 @pytest.mark.skip(reason="Skip since we cannot connect to external database")
 def test_create():
     project_name = f"feathr_registry_client_test_{int(time.time())}"
diff --git a/feathr_project/test/unit/utils/test_job_utils.py b/feathr_project/test/unit/utils/test_job_utils.py
index ae0c8d24d..17f212b67 100644
--- a/feathr_project/test/unit/utils/test_job_utils.py
+++ b/feathr_project/test/unit/utils/test_job_utils.py
@@ -220,7 +220,9 @@ def test__get_result_df(
 #         ("delta", "output-delta", 5),
 #     ],
 # )
-@pytest.mark.skip(reason="Skip since this is not in a spark session. This test should alreayd be covered by `test__get_result_df`. ")
+@pytest.mark.skip(
+    reason="Skip since this is not in a spark session. This test should alreayd be covered by `test__get_result_df`. "
+)
 def test__get_result_df__with_spark_session(
     workspace_dir: str,
     spark: SparkSession,
diff --git a/feathr_project/test/unit/utils/test_platform.py b/feathr_project/test/unit/utils/test_platform.py
index 48a4c4835..545e0f08f 100644
--- a/feathr_project/test/unit/utils/test_platform.py
+++ b/feathr_project/test/unit/utils/test_platform.py
@@ -2,6 +2,7 @@
 Currently, we only test the negative cases, running on non-notebook platform.
 We may submit the test codes to databricks and synapse cluster to confirm the behavior in the future.
 """
+
 from feathr.utils.platform import is_jupyter, is_databricks, is_synapse