From 89b6842feed1a0f64ab4ca4503a42105ed5c9f08 Mon Sep 17 00:00:00 2001
From: qfliu
Date: Fri, 7 Dec 2018 13:22:42 -0800
Subject: [PATCH] IMP add argmax function

This function returns a value from the row that maximizes some other set
of columns within a group or dataframe.
---
 CHANGELOG.md                        |   5 +-
 docs/source/functions.rst           |   1 +
 sparkly/__init__.py                 |   2 +-
 sparkly/functions.py                |  38 ++++++
 tests/integration/test_functions.py | 180 ++++++++++++++++++++++++++++
 5 files changed, 224 insertions(+), 2 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9a0281c..81d739a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,6 @@
+## 2.6.0
+* Add argmax function to sparkly.functions
+
 ## 2.5.1
 * Fix port issue with reading and writing `by_url`. `urlparse` return `netloc` with port,
 which breaks read and write from MySQL and Cassandra.
@@ -28,7 +31,7 @@
 
 # 2.2.1
 * `spark.sql.shuffle.partitions` in `SparklyTest` should be set to string,
-because `int` value breaks integration testing in Spark 2.0.2. 
+because `int` value breaks integration testing in Spark 2.0.2.
 
 # 2.2.0
 * Add instant iterative development mode. `sparkly-testing --help` for more details.
diff --git a/docs/source/functions.rst b/docs/source/functions.rst
index 71ad12c..8b01e28 100644
--- a/docs/source/functions.rst
+++ b/docs/source/functions.rst
@@ -5,6 +5,7 @@ A counterpart of pyspark.sql.functions providing useful shortcuts:
 
 - a cleaner alternative to chaining together multiple when/otherwise statements.
 - an easy way to join multiple dataframes at once and disambiguate fields with the same name.
+- an aggregate function to select a value from the row that maximizes other column(s).
 
 
 API documentation
diff --git a/sparkly/__init__.py b/sparkly/__init__.py
index af8573f..97a6317 100644
--- a/sparkly/__init__.py
+++ b/sparkly/__init__.py
@@ -19,4 +19,4 @@
 assert SparklySession
 
 
-__version__ = '2.5.1'
+__version__ = '2.6.0'
diff --git a/sparkly/functions.py b/sparkly/functions.py
index baa3586..48cf55a 100644
--- a/sparkly/functions.py
+++ b/sparkly/functions.py
@@ -17,6 +17,7 @@
 from collections import defaultdict
 from functools import reduce
 import operator
+from six import string_types
 
 from pyspark.sql import Column
 from pyspark.sql import functions as F
@@ -153,3 +154,40 @@ def _execute_case(accumulator, case):
     result = reduce(_execute_case, cases, F).otherwise(default)
 
     return result
+
+
+def argmax(field, by, condition=None):
+    """Select a value from the row that maximizes other column(s).
+
+    Args:
+        field (string, pyspark.sql.Column): the field to return from the row that maximizes the `by` columns
+        by (string, pyspark.sql.Column, or a list of either): field or list of fields to maximize. Usually
+            a single field is enough, but you may pass several to break ties
+        condition (pyspark.sql.Column, optional): only consider rows that satisfy this condition
+
+    Returns:
+        pyspark.sql.Column
+
+    Example:
+        df = (
+            df
+            .groupBy('id')
+            .agg(argmax('field1', 'by_field'))
+        )
+
+        argmax('field1', ['by_field1', 'by_field2'], condition=F.col('col') == 1)
+        argmax(F.col('field1'), [F.col('by_field1'), F.col('by_field2')], condition=F.lit(True))
+    """
+    if not isinstance(by, list):
+        by = [by]
+
+    if isinstance(field, string_types):
+        field = F.col(field)
+
+    by = by + [field.alias('__tmp_argmax__')]  # copy rather than append, to avoid mutating the caller's list
+    result = F.struct(*by)
+    if condition is not None:
+        result = F.when(condition, result)
+    result = F.max(result).getField('__tmp_argmax__')
+
+    return result
diff --git a/tests/integration/test_functions.py b/tests/integration/test_functions.py
index 2560a09..e4ca2da 100644
--- a/tests/integration/test_functions.py
+++ b/tests/integration/test_functions.py
@@ -359,3 +359,183 @@ def test_switch_case_with_custom_operand_lt(self):
                 {'value': 0, 'value_2': 'worst'},
             ],
         )
+
+class TestArgmax(SparklyGlobalSessionTest):
+    session = SparklyTestSession
+
+    def test_non_nullable_values(self):
+        df = self.spark.createDataFrame(
+            data=[
+                ('1', 'test1', None, 3),
+                ('1', None, 2, 4),
+                ('2', 'test2', 3, 1),
+                ('2', 'test3', 4, 2),
+            ],
+            schema=T.StructType([
+                T.StructField('id', T.StringType(), nullable=True),
+                T.StructField('value1', T.StringType(), nullable=True),
+                T.StructField('value2', T.IntegerType(), nullable=True),
+                T.StructField('target', T.IntegerType(), nullable=True),
+            ]),
+        )
+
+        df = (
+            df
+            .groupBy('id')
+            .agg(
+                F.max('target').alias('target'),
+                *[
+                    SF.argmax(col, 'target', condition=F.col(col).isNotNull()).alias(col)
+                    for col in df.columns
+                    if col not in ['id', 'target']
+                ]
+            )
+        )
+
+        self.assertDataFrameEqual(
+            df,
+            [
+                {'id': '1', 'target': 4, 'value1': 'test1', 'value2': 2},
+                {'id': '2', 'target': 2, 'value1': 'test3', 'value2': 4},
+            ],
+        )
+
+    def test_nullable_values(self):
+        df = self.spark.createDataFrame(
+            data=[
+                ('1', 'test1', None, 3),
+                ('1', None, 2, 4),
+                ('2', 'test2', 3, 1),
+                ('2', 'test3', 4, 2),
+            ],
+            schema=T.StructType([
+                T.StructField('id', T.StringType(), nullable=True),
+                T.StructField('value1', T.StringType(), nullable=True),
+                T.StructField('value2', T.IntegerType(), nullable=True),
+                T.StructField('target', T.IntegerType(), nullable=True),
+            ]),
+        )
+
+        df = (
+            df
+            .groupBy('id')
+            .agg(
+                F.max('target').alias('target'),
+                *[
+                    SF.argmax(col, 'target').alias(col)
+                    for col in df.columns
+                    if col not in ['id', 'target']
+                ]
+            )
+        )
+
+        self.assertDataFrameEqual(
+            df,
+            [
+                {'id': '1', 'target': 4, 'value1': None, 'value2': 2},
+                {'id': '2', 'target': 2, 'value1': 'test3', 'value2': 4},
+            ],
+        )
+
+    def test_break_ties(self):
+        df = self.spark.createDataFrame(
+            data=[
+                ('1', 'test1', 1, 4),
+                ('1', 'test2', 1, 3),
+                ('2', 'test3', 1, 4),
+                ('2', 'test4', 2, 3),
+            ],
+            schema=T.StructType([
+                T.StructField('id', T.StringType(), nullable=True),
+                T.StructField('value', T.StringType(), nullable=True),
+                T.StructField('target1', T.IntegerType(), nullable=True),
+                T.StructField('target2', T.IntegerType(), nullable=True),
+            ]),
+        )
+
+        df = (
+            df
+            .groupBy('id')
+            .agg(
+                SF.argmax('value', ['target1', 'target2']).alias('value')
+            )
+        )
+
+        self.assertDataFrameEqual(
+            df,
+            [
+                {'id': '1', 'value': 'test1'},
+                {'id': '2', 'value': 'test4'},
+            ],
+        )
+
+    def test_with_conditions(self):
+        df = self.spark.createDataFrame(
+            data=[
+                ('1', 'test1', 2),
+                ('1', 'test2', 1),
+                ('2', 'test3', 1),
+                ('2', 'test4', 2),
+            ],
+            schema=T.StructType([
+                T.StructField('id', T.StringType(), nullable=True),
+                T.StructField('value', T.StringType(), nullable=True),
+                T.StructField('target1', T.IntegerType(), nullable=True),
+            ]),
+        )
+
+        df = (
+            df
+            .groupBy('id')
+            .agg(
+                SF.argmax(
+                    'value',
+                    'target1',
+                    condition=F.col('value') != 'test1',
+                ).alias('value'),
+            )
+        )
+
+        self.assertDataFrameEqual(
+            df,
+            [
+                {'id': '1', 'value': 'test2'},
+                {'id': '2', 'value': 'test4'},
+            ],
+        )
+
+    def test_with_column_expressions(self):
+        df = self.spark.createDataFrame(
+            data=[
+                ('1', None, 'test1', 1, 4),
+                ('1', 'test2', 'test2_1', 1, 3),
+                ('2', 'test3', None, 1, 4),
+                ('2', 'test4', 'test5', 2, 6),
+            ],
+            schema=T.StructType([
+                T.StructField('id', T.StringType(), nullable=True),
+                T.StructField('value1', T.StringType(), nullable=True),
+                T.StructField('value2', T.StringType(), nullable=True),
+                T.StructField('target1', T.IntegerType(), nullable=True),
+                T.StructField('target2', T.IntegerType(), nullable=True),
+            ]),
+        )
+
+        df = (
+            df
+            .groupBy('id')
+            .agg(
+                SF.argmax(
+                    F.coalesce(F.col('value1'), F.col('value2')),
+                    F.col('target1') + F.col('target2'),
+                ).alias('value'),
+            )
+        )
+
+        self.assertDataFrameEqual(
+            df,
+            [
+                {'id': '1', 'value': 'test1'},
+                {'id': '2', 'value': 'test4'},
+            ],
+        )
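
Why this works: Spark compares structs lexicographically by field position, so
max(struct(by..., field)) keeps the struct built from the row with the largest
`by` value(s), and getField('__tmp_argmax__') then extracts `field` from that
winning struct. Below is a minimal standalone sketch of the same trick; the
local session, toy data, and the `argmax_value` alias are illustrative
assumptions, not part of the patch:

    from pyspark.sql import SparkSession
    from pyspark.sql import functions as F

    # Hypothetical local session and toy data, for illustration only.
    spark = SparkSession.builder.master('local[1]').getOrCreate()
    df = spark.createDataFrame(
        [('a', 'x', 1), ('a', 'y', 3), ('b', 'z', 2)],
        ['id', 'value', 'target'],
    )

    # max over the struct picks the row with the largest `target` per group;
    # getField pulls the tagged `value` column back out of the winning struct.
    result = df.groupBy('id').agg(
        F.max(F.struct(F.col('target'), F.col('value').alias('__tmp_argmax__')))
        .getField('__tmp_argmax__')
        .alias('argmax_value'),
    )
    result.show()
    # id 'a' yields 'y' (target 3 beats 1); id 'b' yields 'z'.

Note that when a condition is supplied, rows failing it produce NULL structs,
which F.max ignores, so they can never win the argmax.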