diff --git a/kafka-test/src/test/scala/org/apache/pekko/projection/kafka/internal/KafkaSourceProviderImplSpec.scala b/kafka-test/src/test/scala/org/apache/pekko/projection/kafka/internal/KafkaSourceProviderImplSpec.scala index 62a9fc74..95522295 100644 --- a/kafka-test/src/test/scala/org/apache/pekko/projection/kafka/internal/KafkaSourceProviderImplSpec.scala +++ b/kafka-test/src/test/scala/org/apache/pekko/projection/kafka/internal/KafkaSourceProviderImplSpec.scala @@ -37,13 +37,11 @@ import org.scalatest.wordspec.AnyWordSpecLike object KafkaSourceProviderImplSpec { private val TestProjectionId = ProjectionId("test-projection", "00") - def handler(probe: TestProbe[ConsumerRecord[String, String]], - assertFunction: TestProbe[ConsumerRecord[String, String]] => Future[Done]) - : Handler[ConsumerRecord[String, String]] = + def handler(probe: TestProbe[ConsumerRecord[String, String]]): Handler[ConsumerRecord[String, String]] = new Handler[ConsumerRecord[String, String]] { override def process(env: ConsumerRecord[String, String]): Future[Done] = { probe.ref ! env - assertFunction(probe) + Future.successful(Done) } } @@ -72,9 +70,10 @@ class KafkaSourceProviderImplSpec extends ScalaTestWithActorTestKit with LogCapt val metadataClient = new TestMetadataClientAdapter(partitions) val tp0 = new TopicPartition(topic, 0) val tp1 = new TopicPartition(topic, 1) + val totalPerPartition = 10 val consumerRecords = - for (n <- 0 to 10; tp <- List(tp0, tp1)) + for (n <- 0 to totalPerPartition; tp <- List(tp0, tp1)) yield new ConsumerRecord(tp.topic(), tp.partition(), n, n.toString, n.toString) val consumerSource = Source(consumerRecords) @@ -95,14 +94,7 @@ class KafkaSourceProviderImplSpec extends ScalaTestWithActorTestKit with LogCapt } val probe = testKit.createTestProbe[ConsumerRecord[String, String]]() - val records = Set.empty[ConsumerRecord[String, String]] - val projection = TestProjection(TestProjectionId, provider, - () => - handler(probe, - p => { - records ++= p.receiveMessage() - Future.successful(Done) - })) + val projection = TestProjection(TestProjectionId, provider, () => handler(probe)) projectionTestKit.runWithTestSink(projection) { sinkProbe => provider.partitionHandler.onAssign(Set(tp0, tp1), null) @@ -110,13 +102,20 @@ class KafkaSourceProviderImplSpec extends ScalaTestWithActorTestKit with LogCapt sinkProbe.request(10) sinkProbe.expectNextN(10) + var records = probe.receiveMessages(10) withClue("checking: processed records contain 5 from each partition") { - records.toSeq.length shouldBe 10 + records.length shouldBe 10 records.count(_.partition() == tp0.partition()) shouldBe 5 records.count(_.partition() == tp1.partition()) shouldBe 5 } + // because source push to handle(probe) before sinkProbe request pull, it made probe cache random one record + val eagerMessage = probe.receiveMessage() + records = records ++ Set(eagerMessage) + val tp0Received = records.count(_.partition() == tp0.partition()) + val tp0Expect = totalPerPartition - tp0Received + // assign only tp0 to this projection provider.partitionHandler.onAssign(Set(tp0), null) provider.partitionHandler.onRevoke(Set(tp1), null) @@ -128,10 +127,11 @@ class KafkaSourceProviderImplSpec extends ScalaTestWithActorTestKit with LogCapt // only records from partition 0 should remain, because the rest were filtered sinkProbe.request(5) sinkProbe.expectNextN(5) + records = probe.receiveMessages(tp0Expect) withClue("checking: after rebalance processed records should only have records from partition 0") { - records.count(_.partition() == tp0.partition()) shouldBe 10 - records.count(_.partition() == tp1.partition()) shouldBe 5 + records.count(_.partition() == tp0.partition()) shouldBe tp0Expect + records.count(_.partition() == tp1.partition()) shouldBe 0 } } }