IMP add KafkaWatcher testing util
This context manager can be used to make it easier to test
Spark jobs that write to Kafka.
albertc1 committed May 9, 2017
1 parent c2402fb commit e1795fa
Showing 6 changed files with 174 additions and 5 deletions.
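For orientation before the file-by-file diff, here is a minimal sketch of how the new KafkaWatcher is intended to be used in an integration test. It is illustrative only: the test class name, topic, JSON deserializer, the publish_meals helper, the message count, and expected_rows are assumptions, not part of this commit; the constructor arguments mirror the KafkaWatcher added in sparkly/testing.py below.

import json
import uuid

from pyspark.sql.types import IntegerType, StringType, StructField, StructType

from sparkly.testing import KafkaWatcher, SparklyGlobalSessionTest
from tests.integration.base import SparklyTestSession


def json_deserializer(item):
    # The watcher hands raw message bytes to the deserializers.
    return json.loads(item.decode('utf-8'))


class MyKafkaJobTest(SparklyGlobalSessionTest):  # hypothetical test case
    session = SparklyTestSession

    def test_job_publishes_messages(self):
        topic = 'my.topic.{}'.format(uuid.uuid4().hex[:10])
        # Schema of the (key, value) messages the job under test is expected to write.
        expected_schema = StructType([
            StructField('key', StructType([StructField('user_id', IntegerType())])),
            StructField('value', StructType([StructField('meal', StringType())])),
        ])

        kafka_watcher = KafkaWatcher(
            self.spark,
            expected_schema,
            json_deserializer,  # key deserializer
            json_deserializer,  # value deserializer
            'kafka.docker',     # Kafka host used by the integration environment
            topic,
        )
        with kafka_watcher:
            publish_meals(self.spark, topic)  # hypothetical code under test

        # Offsets are compared on exit, so assertions go outside the `with` block.
        self.assertEqual(kafka_watcher.count, 4)
        self.assertDataFrameEqual(kafka_watcher.df, expected_rows)  # expected_rows: hypothetical list of dicts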
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -1,3 +1,7 @@
## 2.0.3
* Add KafkaWatcher to facilitate testing of writing to Kafka
* Fix a few minor pyflakes warnings and typos

## 2.0.2
* Fix: #40 write_ext.kafka ignores errors.

2 changes: 1 addition & 1 deletion sparkly/__init__.py
@@ -19,4 +19,4 @@
assert SparklySession


__version__ = '2.0.2'
__version__ = '2.0.3'
107 changes: 105 additions & 2 deletions sparkly/testing.py
@@ -21,8 +21,9 @@
import shutil
from unittest import TestCase

from sparkly.exceptions import FixtureError
from sparkly import SparklySession
from sparkly.exceptions import FixtureError
from sparkly.utils import kafka_get_topics_offsets

if sys.version_info.major == 3:
    from http.client import HTTPConnection
@@ -46,7 +47,7 @@
    MYSQL_FIXTURES_SUPPORT = False

try:
    from kafka import KafkaProducer
    from kafka import KafkaProducer, SimpleClient
    KAFKA_FIXTURES_SUPPORT = True
except ImportError:
    KAFKA_FIXTURES_SUPPORT = False
@@ -447,3 +448,105 @@ def setup_data(self):

    def teardown_data(self):
        pass


class KafkaWatcher:
    """Context manager that tracks Kafka data published to a topic.

    Provides access to the new items that were written to a Kafka topic by code running
    within this context.

    NOTE: This is mainly useful in integration test cases and may produce unexpected results in
    production environments, since there are no guarantees about who else may be publishing to
    a Kafka topic.

    Usage:

        my_deserializer = lambda item: json.loads(item.decode('utf-8'))
        kafka_watcher = KafkaWatcher(
            my_sparkly_session,
            expected_output_dataframe_schema,
            my_deserializer,
            my_deserializer,
            'my.kafkaserver.net',
            'my_kafka_topic',
        )
        with kafka_watcher:
            # do stuff that publishes messages to 'my_kafka_topic'

        self.assertEqual(kafka_watcher.count, expected_number_of_new_messages)
        self.assertDataFrameEqual(kafka_watcher.df, expected_df)
    """

    def __init__(
        self,
        spark,
        df_schema,
        key_deserializer,
        value_deserializer,
        host,
        topic,
        port=9092,
    ):
        """Initialize the context manager.

        Parameters `key_deserializer` and `value_deserializer` are callables
        which get bytes as input and should return Python structures as output.

        Args:
            spark (SparklySession): currently active SparklySession
            df_schema (pyspark.sql.types.StructType): schema of dataframe to be generated
            key_deserializer (function): function used to deserialize the key
            value_deserializer (function): function used to deserialize the value
            host (basestring): host or IP address of the Kafka server to connect to
            topic (basestring): Kafka topic to monitor
            port (int): port number of the Kafka server to connect to
        """
        self.spark = spark
        self.topic = topic
        self.df_schema = df_schema
        self.key_deser, self.val_deser = key_deserializer, value_deserializer
        self.host, self.port = host, port
        self._df = None
        self.count = 0

        kafka_client = SimpleClient(host)
        kafka_client.ensure_topic_exists(topic)

    def __enter__(self):
        # Record the per-partition offsets before the code under test runs.
        self._df = None
        self.count = 0
        self.pre_offsets = kafka_get_topics_offsets(
            topic=self.topic,
            host=self.host,
            port=self.port,
        )

    def __exit__(self, e_type, e_value, e_trace):
        # Snapshot the offsets again and count messages published inside the context.
        self.post_offsets = kafka_get_topics_offsets(
            topic=self.topic,
            host=self.host,
            port=self.port,
        )
        self.count = sum([
            post[2] - pre[2]
            for pre, post in zip(self.pre_offsets, self.post_offsets)
        ])

    @property
    def df(self):
        # Lazily read only the newly published offset range back into a dataframe.
        if not self.count:
            return None
        if not self._df:
            offset_ranges = [
                [pre[0], pre[2], post[2]]
                for pre, post in zip(self.pre_offsets, self.post_offsets)
            ]
            self._df = self.spark.read_ext.kafka(
                topic=self.topic,
                offset_ranges=offset_ranges,
                schema=self.df_schema,
                key_deserializer=self.key_deser,
                value_deserializer=self.val_deser,
                host=self.host,
                port=self.port,
            )
        return self._df
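A note on the offset bookkeeping above: kafka_get_topics_offsets (imported from sparkly.utils, not shown in this diff) appears to return one (partition, start_offset, end_offset) tuple per partition, and the watcher derives both count and the offset_ranges passed to read_ext.kafka from the snapshots taken in __enter__ and __exit__. A minimal sketch of that arithmetic, with made-up offsets:

# Hypothetical snapshots for a two-partition topic, as taken on __enter__ and __exit__.
pre_offsets = [(0, 0, 10), (1, 0, 7)]    # (partition, start, end) before the block ran
post_offsets = [(0, 0, 14), (1, 0, 9)]   # the same partitions afterwards

# Messages published inside the context: (14 - 10) + (9 - 7) = 6.
count = sum(post[2] - pre[2] for pre, post in zip(pre_offsets, post_offsets))

# Ranges handed to read_ext.kafka: [partition, old_end_offset, new_end_offset].
offset_ranges = [
    [pre[0], pre[2], post[2]]
    for pre, post in zip(pre_offsets, post_offsets)
]

assert count == 6
assert offset_ranges == [[0, 10, 14], [1, 7, 9]]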
4 changes: 4 additions & 0 deletions tests/integration/resources/test_testing/kafka_watcher_1.json
@@ -0,0 +1,4 @@
{"key": {"user_id": 1}, "value": {"meal": "dinner", "food": ["spaghetti", "meatballs"]}}
{"key": {"user_id": 2}, "value": {"meal": "lunch", "food": ["soylent"]}}
{"key": {"user_id": 3}, "value": {"meal": "breakfast", "food": []}}
{"key": {"user_id": 2}, "value": {"meal": "second dinner", "food": ["galbi", "ice cream"]}}
3 changes: 3 additions & 0 deletions tests/integration/resources/test_testing/kafka_watcher_2.json
@@ -0,0 +1,3 @@
{"key": {"user_id": 1}, "value": {"meal": "lunch", "food": ["pizza", "stinky tofu"]}}
{"key": {"user_id": 4}, "value": {"meal": "lunch", "food": ["cuban sandwich", "mashed potatoes"]}}
{"key": {"user_id": 5}, "value": {"meal": "dessert", "food": ["pecan pie", "mango"]}}
59 changes: 57 additions & 2 deletions tests/integration/test_testing.py
@@ -15,18 +15,20 @@
#
import json
import uuid
import pickle

from sparkly.testing import (
    CassandraFixture,
    ElasticFixture,
    MysqlFixture,
    SparklyGlobalSessionTest,
    KafkaFixture)
    KafkaFixture,
    KafkaWatcher)
from sparkly.utils import absolute_path
from tests.integration.base import SparklyTestSession

try:
    from kafka import KafkaConsumer
    from kafka import KafkaConsumer, KafkaProducer
except ImportError:
    pass

@@ -156,3 +158,56 @@ def test_kafka_fixture(self):
            absolute_path(__file__, 'resources', 'test_fixtures', 'kafka.json')
        )
        self.assertDataFrameEqual(expected_data, actual_data)


class TestKafkaWatcher(SparklyGlobalSessionTest):
    session = SparklyTestSession

    def test_write_kafka_dataframe(self):
        host = 'kafka.docker'
        topic = 'test.topic.kafkawatcher.{}'.format(uuid.uuid4().hex[:10])
        port = 9092
        input_df, expected_data = self.get_test_data('kafka_watcher_1.json')

        kafka_watcher = KafkaWatcher(
            self.spark,
            input_df.schema,
            pickle.loads,
            pickle.loads,
            host,
            topic,
            port,
        )
        with kafka_watcher:
            expected_count = self.write_data(input_df, host, topic, port)
        self.assertEqual(kafka_watcher.count, expected_count)
        self.assertDataFrameEqual(kafka_watcher.df, expected_data)

        with kafka_watcher:
            pass
        self.assertEqual(kafka_watcher.count, 0)
        self.assertIsNone(kafka_watcher.df)

        input_df, expected_data = self.get_test_data('kafka_watcher_2.json')
        with kafka_watcher:
            expected_count = self.write_data(input_df, host, topic, port)
        self.assertEqual(kafka_watcher.count, expected_count)
        self.assertDataFrameEqual(kafka_watcher.df, expected_data)

    def get_test_data(self, filename):
        file_path = absolute_path(__file__, 'resources', 'test_testing', filename)
        df = self.spark.read.json(file_path)
        data = [item.asDict(recursive=True) for item in df.collect()]
        return df, data

    def write_data(self, df, host, topic, port):
        producer = KafkaProducer(
            bootstrap_servers=['{}:{}'.format(host, port)],
            key_serializer=pickle.dumps,
            value_serializer=pickle.dumps,
        )
        rows = df.collect()
        for row in rows:
            producer.send(topic, key=row.key, value=row.value)
        producer.flush()
        return len(rows)
