diff --git a/CHANGELOG.md b/CHANGELOG.md
index b0cdde4..b8883d3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,7 @@
+## 2.0.3
+* Add KafkaWatcher to facilitate testing of writing to Kafka
+* Fix a few minor pyflakes warnings and typos
+
 ## 2.0.2
 * Fix: #40 write_ext.kafka ignores errors.
 
diff --git a/sparkly/__init__.py b/sparkly/__init__.py
index f5d1e5c..616727d 100644
--- a/sparkly/__init__.py
+++ b/sparkly/__init__.py
@@ -19,4 +19,4 @@
 
 assert SparklySession
 
-__version__ = '2.0.2'
+__version__ = '2.0.3'
diff --git a/sparkly/catalog.py b/sparkly/catalog.py
index 324fbe9..4bfb2a1 100644
--- a/sparkly/catalog.py
+++ b/sparkly/catalog.py
@@ -14,12 +14,6 @@
 # limitations under the License.
 #
 
-import logging
-import re
-
-from pyspark.sql import DataFrame
-from pyspark.sql.types import StructType
-
 
 class SparklyCatalog(object):
     """A set of tools to interact with HiveMetastore."""
diff --git a/sparkly/reader.py b/sparkly/reader.py
index a1d3355..2a1f447 100644
--- a/sparkly/reader.py
+++ b/sparkly/reader.py
@@ -233,7 +233,7 @@ def kafka(self,
             `key` and `value`.
 
         Parameters `key_deserializer` and `value_deserializer` are callables
-        which get's bytes as input and should return python structures as output.
+        which get bytes as input and should return python structures as output.
 
         Args:
             host (str): Kafka host.
diff --git a/sparkly/testing.py b/sparkly/testing.py
index 9e0e009..66d5c60 100644
--- a/sparkly/testing.py
+++ b/sparkly/testing.py
@@ -21,8 +21,9 @@
 import shutil
 from unittest import TestCase
 
-from sparkly.exceptions import FixtureError
 from sparkly import SparklySession
+from sparkly.exceptions import FixtureError
+from sparkly.utils import kafka_get_topics_offsets
 
 if sys.version_info.major == 3:
     from http.client import HTTPConnection
@@ -46,7 +47,7 @@
 MYSQL_FIXTURES_SUPPORT = False
 
 try:
-    from kafka import KafkaProducer
+    from kafka import KafkaProducer, SimpleClient
     KAFKA_FIXTURES_SUPPORT = True
 except ImportError:
     KAFKA_FIXTURES_SUPPORT = False
@@ -447,3 +448,105 @@ def setup_data(self):
         pass
     def teardown_data(self):
         pass
+
+
+class KafkaWatcher:
+    """Context manager that tracks Kafka data published to a topic
+
+    Provides access to the new items that were written to a kafka topic by code running
+    within this context.
+
+    NOTE: This is mainly useful in integration test cases and may produce unexpected results in
+    production environments, since there are no guarantees about who else may be publishing to
+    a kafka topic.
+
+    Usage:
+        my_deserializer = lambda item: json.loads(item.decode('utf-8'))
+        kafka_watcher = KafkaWatcher(
+            my_sparkly_session,
+            expected_output_dataframe_schema,
+            my_deserializer,
+            my_deserializer,
+            'my.kafkaserver.net',
+            'my_kafka_topic',
+        )
+        with kafka_watcher:
+            # do stuff that publishes messages to 'my_kafka_topic'
+        self.assertEqual(kafka_watcher.count, expected_number_of_new_messages)
+        self.assertDataFrameEqual(kafka_watcher.df, expected_df)
+    """
+
+    def __init__(
+        self,
+        spark,
+        df_schema,
+        key_deserializer,
+        value_deserializer,
+        host,
+        topic,
+        port=9092,
+    ):
+        """Initialize context manager
+
+        Parameters `key_deserializer` and `value_deserializer` are callables
+        which get bytes as input and should return python structures as output.
+
+        Args:
+            spark (SparklySession): currently active SparklySession
+            df_schema (pyspark.sql.types.StructType): schema of dataframe to be generated
+            key_deserializer (function): function used to deserialize the key
+            value_deserializer (function): function used to deserialize the value
+            host (basestring): host or ip address of the kafka server to connect to
+            topic (basestring): Kafka topic to monitor
+            port (int): port number of the Kafka server to connect to
+        """
+        self.spark = spark
+        self.topic = topic
+        self.df_schema = df_schema
+        self.key_deser, self.val_deser = key_deserializer, value_deserializer
+        self.host, self.port = host, port
+        self._df = None
+        self.count = 0
+
+        kafka_client = SimpleClient(host)
+        kafka_client.ensure_topic_exists(topic)
+
+    def __enter__(self):
+        self._df = None
+        self.count = 0
+        self.pre_offsets = kafka_get_topics_offsets(
+            topic=self.topic,
+            host=self.host,
+            port=self.port,
+        )
+
+    def __exit__(self, e_type, e_value, e_trace):
+        self.post_offsets = kafka_get_topics_offsets(
+            topic=self.topic,
+            host=self.host,
+            port=self.port,
+        )
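+        # Offsets come back as (partition, start_offset, end_offset) tuples, one per
+        # partition, so the number of new messages is the total growth of the end
+        # offsets while the context was active.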
+        self.count = sum([
+            post[2] - pre[2]
+            for pre, post in zip(self.pre_offsets, self.post_offsets)
+        ])
+
+    @property
+    def df(self):
+        if not self.count:
+            return None
+        if not self._df:
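+            # Build the DataFrame lazily, reading back only the offsets that were
+            # produced between __enter__ and __exit__ for each partition.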
+            offset_ranges = [
+                [pre[0], pre[2], post[2]]
+                for pre, post in zip(self.pre_offsets, self.post_offsets)
+            ]
+            self._df = self.spark.read_ext.kafka(
+                topic=self.topic,
+                offset_ranges=offset_ranges,
+                schema=self.df_schema,
+                key_deserializer=self.key_deser,
+                value_deserializer=self.val_deser,
+                host=self.host,
+                port=self.port,
+            )
+        return self._df
diff --git a/sparkly/utils.py b/sparkly/utils.py
index c6450cc..22e3190 100644
--- a/sparkly/utils.py
+++ b/sparkly/utils.py
@@ -14,10 +14,9 @@
 # limitations under the License.
 #
 
-from collections import OrderedDict
+import inspect
 import os
 import re
-import inspect
 
 try:
     from kafka import SimpleClient
diff --git a/tests/integration/resources/test_testing/kafka_watcher_1.json b/tests/integration/resources/test_testing/kafka_watcher_1.json
new file mode 100644
index 0000000..389f3f6
--- /dev/null
+++ b/tests/integration/resources/test_testing/kafka_watcher_1.json
@@ -0,0 +1,4 @@
+{"key": {"user_id": 1}, "value": {"meal": "dinner", "food": ["spaghetti", "meatballs"]}}
+{"key": {"user_id": 2}, "value": {"meal": "lunch", "food": ["soylent"]}}
+{"key": {"user_id": 3}, "value": {"meal": "breakfast", "food": []}}
+{"key": {"user_id": 2}, "value": {"meal": "second dinner", "food": ["galbi", "ice cream"]}}
diff --git a/tests/integration/resources/test_testing/kafka_watcher_2.json b/tests/integration/resources/test_testing/kafka_watcher_2.json
new file mode 100644
index 0000000..68ca76f
--- /dev/null
+++ b/tests/integration/resources/test_testing/kafka_watcher_2.json
@@ -0,0 +1,3 @@
+{"key": {"user_id": 1}, "value": {"meal": "lunch", "food": ["pizza", "stinky tofu"]}}
+{"key": {"user_id": 4}, "value": {"meal": "lunch", "food": ["cuban sandwich", "mashed potatoes"]}}
+{"key": {"user_id": 5}, "value": {"meal": "dessert", "food": ["pecan pie", "mango"]}}
diff --git a/tests/integration/test_catalog.py b/tests/integration/test_catalog.py
index 84ad51d..24a97e2 100644
--- a/tests/integration/test_catalog.py
+++ b/tests/integration/test_catalog.py
@@ -14,9 +14,6 @@
 # limitations under the License.
 #
 
-import uuid
-import os
-
 from sparkly.testing import SparklyGlobalSessionTest
 from tests.integration.base import SparklyTestSession
 
diff --git a/tests/integration/test_testing.py b/tests/integration/test_testing.py
index c382373..084c3e9 100644
--- a/tests/integration/test_testing.py
+++ b/tests/integration/test_testing.py
@@ -15,18 +15,20 @@
 #
 import json
 import uuid
+import pickle
 
 from sparkly.testing import (
     CassandraFixture,
     ElasticFixture,
     MysqlFixture,
     SparklyGlobalSessionTest,
-    KafkaFixture)
+    KafkaFixture,
+    KafkaWatcher)
 from sparkly.utils import absolute_path
 from tests.integration.base import SparklyTestSession
 
 try:
-    from kafka import KafkaConsumer
+    from kafka import KafkaConsumer, KafkaProducer
 except ImportError:
     pass
 
@@ -156,3 +158,56 @@ def test_kafka_fixture(self):
             absolute_path(__file__, 'resources', 'test_fixtures', 'kafka.json')
         )
         self.assertDataFrameEqual(expected_data, actual_data)
+
+
+class TestKafkaWatcher(SparklyGlobalSessionTest):
+    session = SparklyTestSession
+
+    def test_write_kafka_dataframe(self):
+        host = 'kafka.docker'
+        topic = 'test.topic.kafkawatcher.{}'.format(uuid.uuid4().hex[:10])
+        port = 9092
+        input_df, expected_data = self.get_test_data('kafka_watcher_1.json')
+
+        kafka_watcher = KafkaWatcher(
+            self.spark,
+            input_df.schema,
+            pickle.loads,
+            pickle.loads,
+            host,
+            topic,
+            port,
+        )
+        with kafka_watcher:
+            expected_count = self.write_data(input_df, host, topic, port)
+        self.assertEqual(kafka_watcher.count, expected_count)
+        self.assertDataFrameEqual(kafka_watcher.df, expected_data)
+
+        with kafka_watcher:
+            pass
+        self.assertEqual(kafka_watcher.count, 0)
+        self.assertIsNone(kafka_watcher.df)
+
+        input_df, expected_data = self.get_test_data('kafka_watcher_2.json')
+        with kafka_watcher:
+            expected_count = self.write_data(input_df, host, topic, port)
+        self.assertEqual(kafka_watcher.count, expected_count)
+        self.assertDataFrameEqual(kafka_watcher.df, expected_data)
+
+    def get_test_data(self, filename):
+        file_path = absolute_path(__file__, 'resources', 'test_testing', filename)
+        df = self.spark.read.json(file_path)
+        data = [item.asDict(recursive=True) for item in df.collect()]
+        return df, data
+
+    def write_data(self, df, host, topic, port):
+        producer = KafkaProducer(
+            bootstrap_servers=['{}:{}'.format(host, port)],
+            key_serializer=pickle.dumps,
+            value_serializer=pickle.dumps,
+        )
+        rows = df.collect()
+        for row in rows:
+            producer.send(topic, key=row.key, value=row.value)
+        producer.flush()
+        return len(rows)
diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py
index 02f61d4..0cd3801 100644
--- a/tests/unit/test_utils.py
+++ b/tests/unit/test_utils.py
@@ -16,7 +16,6 @@
 
 from unittest import TestCase
 
-from sparkly.exceptions import UnsupportedDataType
 from sparkly.utils import parse_schema
 
 