Skip to content

Commit

Permalink
Merge pull request #42 from Tubular/imp/add-kafkawatcher-context-mana…
Browse files Browse the repository at this point in the history
…ger-for-testing

Imp/add kafkawatcher context manager for testing
  • Loading branch information
albertc1 authored May 9, 2017
2 parents c2ed31c + e1795fa commit ba137da
Show file tree
Hide file tree
Showing 11 changed files with 176 additions and 18 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
## 2.0.3
* Add KafkaWatcher to facilitate testing of writing to Kafka
* Fix a few minor pyflakes warnings and typos

## 2.0.2
* Fix: #40 write_ext.kafka ignores errors.

Expand Down
2 changes: 1 addition & 1 deletion sparkly/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@
assert SparklySession


__version__ = '2.0.2'
__version__ = '2.0.3'
6 changes: 0 additions & 6 deletions sparkly/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,6 @@
# limitations under the License.
#

import logging
import re

from pyspark.sql import DataFrame
from pyspark.sql.types import StructType


class SparklyCatalog(object):
"""A set of tools to interact with HiveMetastore."""
Expand Down
2 changes: 1 addition & 1 deletion sparkly/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def kafka(self,
`key` and `value`.
Parameters `key_deserializer` and `value_deserializer` are callables
which get's bytes as input and should return python structures as output.
which get bytes as input and should return python structures as output.
Args:
host (str): Kafka host.
Expand Down
107 changes: 105 additions & 2 deletions sparkly/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@
import shutil
from unittest import TestCase

from sparkly.exceptions import FixtureError
from sparkly import SparklySession
from sparkly.exceptions import FixtureError
from sparkly.utils import kafka_get_topics_offsets

if sys.version_info.major == 3:
from http.client import HTTPConnection
Expand All @@ -46,7 +47,7 @@
MYSQL_FIXTURES_SUPPORT = False

try:
from kafka import KafkaProducer
from kafka import KafkaProducer, SimpleClient
KAFKA_FIXTURES_SUPPORT = True
except ImportError:
KAFKA_FIXTURES_SUPPORT = False
Expand Down Expand Up @@ -447,3 +448,105 @@ def setup_data(self):

# No-op teardown hook: the body is a bare `pass`, so nothing is cleaned up here.
# NOTE(review): the enclosing class is outside this excerpt -- confirm that an
# empty teardown is intentional for this fixture.
def teardown_data(self):
pass


class KafkaWatcher(object):
    """Context manager that tracks Kafka data published to a topic.

    Provides access to the new items that were written to a kafka topic by code running
    within this context.

    NOTE: This is mainly useful in integration test cases and may produce unexpected results in
    production environments, since there are no guarantees about who else may be publishing to
    a kafka topic.

    Usage:
        my_deserializer = lambda item: json.loads(item.decode('utf-8'))
        kafka_watcher = KafkaWatcher(
            my_sparkly_session,
            expected_output_dataframe_schema,
            my_deserializer,
            my_deserializer,
            'my.kafkaserver.net',
            'my_kafka_topic',
        )
        with kafka_watcher:
            # do stuff that publishes messages to 'my_kafka_topic'
        self.assertEqual(kafka_watcher.count, expected_number_of_new_messages)
        self.assertDataFrameEqual(kafka_watcher.df, expected_df)

    Attributes:
        count (int): number of messages published while the context was active
            (0 until the context has been exited).
        df: dataframe with the messages published within the context,
            or `None` if there were none.
    """

    def __init__(
        self,
        spark,
        df_schema,
        key_deserializer,
        value_deserializer,
        host,
        topic,
        port=9092,
    ):
        """Initialize context manager.

        Parameters `key_deserializer` and `value_deserializer` are callables
        which get bytes as input and should return python structures as output.

        Args:
            spark (SparklySession): currently active SparklySession
            df_schema (pyspark.sql.types.StructType): schema of dataframe to be generated
            key_deserializer (function): function used to deserialize the key
            value_deserializer (function): function used to deserialize the value
            host (basestring): host or ip address of the kafka server to connect to
            topic (basestring): Kafka topic to monitor
            port (int): port number of the Kafka server to connect to
        """
        self.spark = spark
        self.topic = topic
        self.df_schema = df_schema
        self.key_deser, self.val_deser = key_deserializer, value_deserializer
        self.host, self.port = host, port
        self._df = None
        self.count = 0
        # Defined up front so the attributes exist even before the context is entered.
        self.pre_offsets = []
        self.post_offsets = []

        # Include the port in the broker address: connecting with the bare host
        # would silently ignore a non-default `port`.
        kafka_client = SimpleClient('{}:{}'.format(host, port))
        try:
            kafka_client.ensure_topic_exists(topic)
        finally:
            # The client is only needed for the topic check; do not leak its sockets.
            kafka_client.close()

    def __enter__(self):
        """Reset the state and snapshot the topic offsets before the watched code runs.

        Returns:
            KafkaWatcher: this instance, so `with KafkaWatcher(...) as watcher:` works.
        """
        self._df = None
        self.count = 0
        self.pre_offsets = kafka_get_topics_offsets(
            topic=self.topic,
            host=self.host,
            port=self.port,
        )
        return self

    def __exit__(self, e_type, e_value, e_trace):
        """Snapshot the offsets again and count the messages added in between.

        Returns nothing, so an exception raised inside the context still propagates.
        """
        self.post_offsets = kafka_get_topics_offsets(
            topic=self.topic,
            host=self.host,
            port=self.port,
        )
        # NOTE(review): assumes each offsets item is indexable with the partition's
        # end offset at index 2, and that both snapshots list the partitions in the
        # same order -- confirm against `kafka_get_topics_offsets`.
        self.count = sum(
            post[2] - pre[2]
            for pre, post in zip(self.pre_offsets, self.post_offsets)
        )

    @property
    def df(self):
        """Dataframe of the messages published within the context (read lazily, then cached).

        Returns:
            The dataframe returned by `spark.read_ext.kafka`, or `None` when no
            new messages were counted.
        """
        if not self.count:
            return None
        # Compare against None explicitly: the cache check must not depend on the
        # truthiness of whatever object the reader returned.
        if self._df is None:
            # NOTE(review): index 0 is presumably the partition id -- confirm against
            # `kafka_get_topics_offsets`. Each range spans the end offset before the
            # context to the end offset after it.
            offset_ranges = [
                [pre[0], pre[2], post[2]]
                for pre, post in zip(self.pre_offsets, self.post_offsets)
            ]
            self._df = self.spark.read_ext.kafka(
                topic=self.topic,
                offset_ranges=offset_ranges,
                schema=self.df_schema,
                key_deserializer=self.key_deser,
                value_deserializer=self.val_deser,
                host=self.host,
                port=self.port,
            )
        return self._df
3 changes: 1 addition & 2 deletions sparkly/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@
# limitations under the License.
#

from collections import OrderedDict
import inspect
import os
import re
import inspect

try:
from kafka import SimpleClient
Expand Down
4 changes: 4 additions & 0 deletions tests/integration/resources/test_testing/kafka_watcher_1.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{"key": {"user_id": 1}, "value": {"meal": "dinner", "food": ["spaghetti", "meatballs"]}}
{"key": {"user_id": 2}, "value": {"meal": "lunch", "food": ["soylent"]}}
{"key": {"user_id": 3}, "value": {"meal": "breakfast", "food": []}}
{"key": {"user_id": 2}, "value": {"meal": "second dinner", "food": ["galbi", "ice cream"]}}
3 changes: 3 additions & 0 deletions tests/integration/resources/test_testing/kafka_watcher_2.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{"key": {"user_id": 1}, "value": {"meal": "lunch", "food": ["pizza", "stinky tofu"]}}
{"key": {"user_id": 4}, "value": {"meal": "lunch", "food": ["cuban sandwich", "mashed potatoes"]}}
{"key": {"user_id": 5}, "value": {"meal": "dessert", "food": ["pecan pie", "mango"]}}
3 changes: 0 additions & 3 deletions tests/integration/test_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@
# limitations under the License.
#

import uuid
import os

from sparkly.testing import SparklyGlobalSessionTest
from tests.integration.base import SparklyTestSession

Expand Down
59 changes: 57 additions & 2 deletions tests/integration/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,20 @@
#
import json
import uuid
import pickle

from sparkly.testing import (
CassandraFixture,
ElasticFixture,
MysqlFixture,
SparklyGlobalSessionTest,
KafkaFixture)
KafkaFixture,
KafkaWatcher)
from sparkly.utils import absolute_path
from tests.integration.base import SparklyTestSession

try:
from kafka import KafkaConsumer
from kafka import KafkaConsumer, KafkaProducer
except ImportError:
pass

Expand Down Expand Up @@ -156,3 +158,56 @@ def test_kafka_fixture(self):
absolute_path(__file__, 'resources', 'test_fixtures', 'kafka.json')
)
self.assertDataFrameEqual(expected_data, actual_data)


class TestKafkaWatcher(SparklyGlobalSessionTest):
    """Integration test for `KafkaWatcher`; talks to the Kafka broker at `kafka.docker`."""

    session = SparklyTestSession

    def test_write_kafka_dataframe(self):
        """Reuse one watcher for three contexts: write, no write, write again."""
        host = 'kafka.docker'
        # Random suffix gives every run a topic of its own.
        topic = 'test.topic.kafkawatcher.{}'.format(uuid.uuid4().hex[:10])
        port = 9092
        input_df, expected_data = self.get_test_data('kafka_watcher_1.json')

        # NOTE: `pickle.loads` is acceptable here only because this test produces
        # the messages itself; never unpickle data from an untrusted topic.
        kafka_watcher = KafkaWatcher(
            self.spark,
            input_df.schema,
            pickle.loads,
            pickle.loads,
            host,
            topic,
            port,
        )
        with kafka_watcher:
            expected_count = self.write_data(input_df, host, topic, port)
        self.assertEqual(kafka_watcher.count, expected_count)
        self.assertDataFrameEqual(kafka_watcher.df, expected_data)

        # An empty context must report no messages and no dataframe.
        with kafka_watcher:
            pass
        self.assertEqual(kafka_watcher.count, 0)
        self.assertIsNone(kafka_watcher.df)

        # A later context must only see the messages written inside it.
        input_df, expected_data = self.get_test_data('kafka_watcher_2.json')
        with kafka_watcher:
            expected_count = self.write_data(input_df, host, topic, port)
        self.assertEqual(kafka_watcher.count, expected_count)
        self.assertDataFrameEqual(kafka_watcher.df, expected_data)

    def get_test_data(self, filename):
        """Load a JSON resource file.

        Args:
            filename (str): file name inside `resources/test_testing`.

        Returns:
            tuple: the dataframe read from the file and its rows as a list of dicts.
        """
        file_path = absolute_path(__file__, 'resources', 'test_testing', filename)
        df = self.spark.read.json(file_path)
        data = [item.asDict(recursive=True) for item in df.collect()]
        return df, data

    def write_data(self, df, host, topic, port):
        """Publish every row of `df` to `topic` as a pickled key/value pair.

        Args:
            df: dataframe whose rows have `key` and `value` fields.
            host (str): Kafka host.
            topic (str): topic to publish to.
            port (int): Kafka port.

        Returns:
            int: number of rows sent.
        """
        producer = KafkaProducer(
            bootstrap_servers=['{}:{}'.format(host, port)],
            key_serializer=pickle.dumps,
            value_serializer=pickle.dumps,
        )
        try:
            rows = df.collect()
            for row in rows:
                producer.send(topic, key=row.key, value=row.value)
            producer.flush()
        finally:
            # Release the producer's connections even if sending fails.
            producer.close()
        return len(rows)
1 change: 0 additions & 1 deletion tests/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from unittest import TestCase

from sparkly.exceptions import UnsupportedDataType
from sparkly.utils import parse_schema


Expand Down

0 comments on commit ba137da

Please sign in to comment.