IMP add argmax function
This function returns the value of a field taken from the row that maximizes
some other set of columns, within a group or across a whole dataframe.
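A minimal usage sketch, for illustration only: the `events` dataframe, its columns (`user_id`, `page`, `score`), and the `SF` import alias are hypothetical; only `argmax` itself comes from this commit.

from pyspark.sql import functions as F
from sparkly import functions as SF

# For each user, keep the page from the event row with the highest score;
# rows with a NULL score are skipped via the condition.
top_pages = (
    events
    .groupBy('user_id')
    .agg(SF.argmax('page', 'score', condition=F.col('score').isNotNull()).alias('top_page'))
)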
qfliu authored and albertc1 committed Feb 5, 2019
1 parent d74b090 commit 89b6842
Showing 5 changed files with 224 additions and 2 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -1,3 +1,6 @@
## 2.6.0
* Add argmax function to sparkly.functions

## 2.5.1
* Fix port issue with reading and writing `by_url`. `urlparse` returns `netloc` with the port, which breaks reads from and writes to MySQL and Cassandra.

@@ -28,7 +31,7 @@

# 2.2.1
* `spark.sql.shuffle.partitions` in `SparklyTest` should be set to string,
  because `int` value breaks integration testing in Spark 2.0.2.

# 2.2.0
* Add instant iterative development mode. `sparkly-testing --help` for more details.
1 change: 1 addition & 0 deletions docs/source/functions.rst
@@ -5,6 +5,7 @@ A counterpart of pyspark.sql.functions providing useful shortcuts:

- a cleaner alternative to chaining together multiple when/otherwise statements.
- an easy way to join multiple dataframes at once and disambiguate fields with the same name.
- an agg function to select a value from the row that maximizes other column(s).


API documentation
2 changes: 1 addition & 1 deletion sparkly/__init__.py
@@ -19,4 +19,4 @@
assert SparklySession


__version__ = '2.5.1'
__version__ = '2.6.0'
38 changes: 38 additions & 0 deletions sparkly/functions.py
@@ -17,6 +17,7 @@
from collections import defaultdict
from functools import reduce
import operator
from six import string_types

from pyspark.sql import Column
from pyspark.sql import functions as F
@@ -153,3 +154,40 @@ def _execute_case(accumulator, case):
    result = reduce(_execute_case, cases, F).otherwise(default)

    return result


def argmax(field, by, condition=None):
"""Select a value from the row that maximizes other column(s)
Args:
field (string, pyspark.sql.Column): the field to return that maximizes the "by" columns
by (*string, *pyspark.sql.Column): field or list of fields to maximize. In reality, this
will usually be only one field. But you may use multiple for tiebreakers
condition (optional): Only consider the entities that pass this condition
Returns:
pyspark.sql.Column
Example:
df = (
df
.groupBy('id')
.agg(argmax('field1', 'by_field'))
)
argmax('field1', ['by_field1', 'by_field2'], condition=F.col('col') == 1)
argmax(F.col('field1'), [F.col('by_field1'), F.col('by_field2')], condition=F.lit(True))
"""
    if not isinstance(by, list):
        by = [by]

    if isinstance(field, string_types):
        field = F.col(field)

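    # Bundle the ordering columns plus the aliased target field into a struct:
    # max() over a struct compares its fields left to right, so the row with the
    # largest `by` values wins and its `field` value rides along in '__tmp_argmax__'.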
    by.append(field.alias('__tmp_argmax__'))
    result = F.struct(*by)
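    # Rows that fail the condition become NULL (there is no otherwise branch),
    # and max() ignores NULLs, so they never win.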
    if condition is not None:
        result = F.when(condition, result)
    result = F.max(result).getField('__tmp_argmax__')

    return result
180 changes: 180 additions & 0 deletions tests/integration/test_functions.py
@@ -359,3 +359,183 @@ def test_switch_case_with_custom_operand_lt(self):
{'value': 0, 'value_2': 'worst'},
],
)

class TestArgmax(SparklyGlobalSessionTest):
    session = SparklyTestSession

    def test_non_nullable_values(self):
        df = self.spark.createDataFrame(
            data=[
                ('1', 'test1', None, 3),
                ('1', None, 2, 4),
                ('2', 'test2', 3, 1),
                ('2', 'test3', 4, 2),
            ],
            schema=T.StructType([
                T.StructField('id', T.StringType(), nullable=True),
                T.StructField('value1', T.StringType(), nullable=True),
                T.StructField('value2', T.IntegerType(), nullable=True),
                T.StructField('target', T.IntegerType(), nullable=True),
            ]),
        )

        df = (
            df
            .groupBy('id')
            .agg(
                F.max('target').alias('target'),
                *[
                    SF.argmax(col, 'target', condition=F.col(col).isNotNull()).alias(col)
                    for col in df.columns
                    if col not in ['id', 'target']
                ]
            )
        )

        self.assertDataFrameEqual(
            df,
            [
                {'id': '1', 'target': 4, 'value1': 'test1', 'value2': 2},
                {'id': '2', 'target': 2, 'value1': 'test3', 'value2': 4},
            ],
        )

    def test_nullable_values(self):
        df = self.spark.createDataFrame(
            data=[
                ('1', 'test1', None, 3),
                ('1', None, 2, 4),
                ('2', 'test2', 3, 1),
                ('2', 'test3', 4, 2),
            ],
            schema=T.StructType([
                T.StructField('id', T.StringType(), nullable=True),
                T.StructField('value1', T.StringType(), nullable=True),
                T.StructField('value2', T.IntegerType(), nullable=True),
                T.StructField('target', T.IntegerType(), nullable=True),
            ]),
        )

        df = (
            df
            .groupBy('id')
            .agg(
                F.max('target').alias('target'),
                *[
                    SF.argmax(col, 'target').alias(col)
                    for col in df.columns
                    if col not in ['id', 'target']
                ]
            )
        )

        self.assertDataFrameEqual(
            df,
            [
                {'id': '1', 'target': 4, 'value1': None, 'value2': 2},
                {'id': '2', 'target': 2, 'value1': 'test3', 'value2': 4},
            ],
        )

    def test_break_ties(self):
        df = self.spark.createDataFrame(
            data=[
                ('1', 'test1', 1, 4),
                ('1', 'test2', 1, 3),
                ('2', 'test3', 1, 4),
                ('2', 'test4', 2, 3),
            ],
            schema=T.StructType([
                T.StructField('id', T.StringType(), nullable=True),
                T.StructField('value', T.StringType(), nullable=True),
                T.StructField('target1', T.IntegerType(), nullable=True),
                T.StructField('target2', T.IntegerType(), nullable=True),
            ]),
        )

        df = (
            df
            .groupBy('id')
            .agg(
                SF.argmax('value', ['target1', 'target2']).alias('value')
            )
        )

        self.assertDataFrameEqual(
            df,
            [
                {'id': '1', 'value': 'test1'},
                {'id': '2', 'value': 'test4'},
            ],
        )

    def test_with_conditions(self):
        df = self.spark.createDataFrame(
            data=[
                ('1', 'test1', 2),
                ('1', 'test2', 1),
                ('2', 'test3', 1),
                ('2', 'test4', 2),
            ],
            schema=T.StructType([
                T.StructField('id', T.StringType(), nullable=True),
                T.StructField('value', T.StringType(), nullable=True),
                T.StructField('target1', T.IntegerType(), nullable=True),
            ]),
        )

        df = (
            df
            .groupBy('id')
            .agg(
                SF.argmax(
                    'value',
                    'target1',
                    condition=F.col('value') != 'test1',
                ).alias('value'),
            )
        )

        self.assertDataFrameEqual(
            df,
            [
                {'id': '1', 'value': 'test2'},
                {'id': '2', 'value': 'test4'},
            ],
        )

    def test_with_column_expressions(self):
        df = self.spark.createDataFrame(
            data=[
                ('1', None, 'test1', 1, 4),
                ('1', 'test2', 'test2_1', 1, 3),
                ('2', 'test3', None, 1, 4),
                ('2', 'test4', 'test5', 2, 6),
            ],
            schema=T.StructType([
                T.StructField('id', T.StringType(), nullable=True),
                T.StructField('value1', T.StringType(), nullable=True),
                T.StructField('value2', T.StringType(), nullable=True),
                T.StructField('target1', T.IntegerType(), nullable=True),
                T.StructField('target2', T.IntegerType(), nullable=True),
            ]),
        )

        df = (
            df
            .groupBy('id')
            .agg(
                SF.argmax(
                    F.coalesce(F.col('value1'), F.col('value2')),
                    F.col('target1') + F.col('target2'),
                ).alias('value'),
            )
        )

        self.assertDataFrameEqual(
            df,
            [
                {'id': '1', 'value': 'test1'},
                {'id': '2', 'value': 'test4'},
            ],
        )
