Skip to content

Commit

Permalink
Add SCollection reify methods (#2987)
Browse files Browse the repository at this point in the history
  • Loading branch information
regadas authored Jun 1, 2020
1 parent f1563f2 commit 4a20995
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1155,4 +1155,19 @@ class PairSCollectionFunctions[K, V](val self: SCollection[(K, V)]) {
.map(_._2.toMap)
)
.asSingletonSideInput(Map.empty[K, Iterable[V]])

/**
* Returns an [[SCollection]] consisting of a single `Map[K, V]` element.
*/
def reifyAsMapInGlobalWindow(implicit ck: Coder[K], cv: Coder[V]): SCollection[Map[K, V]] =
self.reifyInGlobalWindow(_.asMapSideInput)

/**
* Returns an [[SCollection]] consisting of a single `Map[K, Iterable[V]]` element.
*/
def reifyAsMultiMapInGlobalWindow(implicit
ck: Coder[K],
cv: Coder[V]
): SCollection[Map[K, Iterable[V]]] =
self.reifyInGlobalWindow(_.asMultiMapSideInput)
}
75 changes: 75 additions & 0 deletions scio-core/src/main/scala/com/spotify/scio/values/SCollection.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1254,6 +1254,81 @@ sealed trait SCollection[T] extends PCollectionWrapper[T] {
@deprecated("Use readAllAsBytes instead", "0.8.1")
def readAllBytes(implicit ev: T <:< String): SCollection[Array[Byte]] = readFilesAsBytes

/**
* Pairs each element with the value of the provided [[SideInput]] in the element's window.
*
* Reify as List:
* {{{
* val other: SCollection[Int] = sc.parallelize(Seq(1))
* val coll: SCollection[(Int, Seq[Int])] =
* sc.parallelize(Seq(1, 2))
* .reifySideInputAsValues(other.asListSideInput)
* }}}
*
* Reify as Iterable:
* {{{
* val other: SCollection[Int] = sc.parallelize(Seq(1))
* val coll: SCollection[(Int, Iterable[Int])] =
* sc.parallelize(Seq(1, 2))
* .reifySideInputAsValues(other.asIterableSideInput)
* }}}
*
* Reify as Map:
* {{{
* val other: SCollection[(Int, Int)] = sc.parallelize(Seq((1, 1)))
* val coll: SCollection[(Int, Map[Int, Int])] =
* sc.parallelize(Seq(1, 2))
* .reifySideInputAsValues(other.asMapSideInput)
* }}}
*
* Reify as Multimap:
* {{{
* val other: SCollection[(Int, Int)] = sc.parallelize(Seq((1, 1)))
* val coll: SCollection[(Int, Map[Int, Iterable[Int]])] =
* sc.parallelize(Seq(1, 2))
* .reifySideInputAsValues(other.asMultiMapSideInput)
* }}}
*/
def reifySideInputAsValues[U: Coder](side: SideInput[U]): SCollection[(T, U)] = {
implicit val tc = Coder.beam(internal.getCoder)
this.transform(_.withSideInputs(side).map((t, s) => (t, s(side))).toSCollection)
}

/**
* Returns an [[SCollection]] consisting of a single `Seq[T]` element.
*/
def reifyAsListInGlobalWindow(implicit coder: Coder[T]): SCollection[Seq[T]] =
reifyInGlobalWindow(_.asListSideInput)

/**
* Returns an [[SCollection]] consisting of a single `Iterable[T]` element.
*/
def reifyAsIterableInGlobalWindow(implicit coder: Coder[T]): SCollection[Iterable[T]] =
reifyInGlobalWindow(_.asIterableSideInput)

/**
* Returns an [[SCollection]] consisting of a single element, containing the value of the given
* side input in the global window.
*
* Reify as List:
* {{{
* val coll: SCollection[Seq[Int]] =
* sc.parallelize(Seq(1, 2)).reifyInGlobalWindow(_.asListSideInput)
* }}}
*
* Can be used to replace patterns like:
* {{{
* val coll: SCollection[Iterable[Int]] = sc.parallelize(Seq(1, 2)).groupBy(_ => ())
* }}}
* where you want to actually get an empty [[Iterable]] even if no data is present.
*/
private[scio] def reifyInGlobalWindow[U: Coder](
view: SCollection[T] => SideInput[U]
): SCollection[U] =
this.transform(coll =>
context.parallelize[Unit](Seq(())).reifySideInputAsValues(view(coll)).values
)

// =======================================================================
// Write operations
// =======================================================================
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.WriteDisposition
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers


object BigQueryIT {
@AvroType.fromSchema("""{
| "type":"record",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,12 @@

package com.spotify.scio.extra

import java.util
import java.lang.{Iterable => JIterable}
import java.util.{UUID, List => JList}
import java.lang.Math.floorMod
import java.util.{UUID, List => JList, Iterator => JIterator}

import com.spotify.scio.ScioContext
import com.spotify.scio.annotations.experimental
import com.spotify.scio.coders.{Coder, CoderMaterializer}
import com.spotify.scio.coders.Coder
import com.spotify.scio.extra.sparkey.instances.{
CachedStringSparkeyReader,
SparkeyReaderInstances,
Expand All @@ -32,14 +31,13 @@ import com.spotify.scio.extra.sparkey.instances.{
import com.spotify.scio.util.Cache
import com.spotify.scio.values.{SCollection, SideInput}
import com.spotify.sparkey.{IndexHeader, LogHeader, SparkeyReader}
import org.apache.beam.sdk.transforms.{DoFn, Reify, View}
import org.apache.beam.sdk.io.FileSystems
import org.apache.beam.sdk.transforms.{DoFn, View}
import org.apache.beam.sdk.values.PCollectionView
import org.slf4j.LoggerFactory

import scala.jdk.CollectionConverters._
import scala.util.hashing.MurmurHash3
import java.lang.Math.floorMod
import org.apache.beam.sdk.io.FileSystems

/**
* Main package for Sparkey side input APIs. Import all.
Expand Down Expand Up @@ -186,7 +184,7 @@ package object sparkey extends SparkeyReaderInstances {
private def writeToSparkey[K, V](
uri: SparkeyUri,
maxMemoryUsage: Long,
elements: JIterable[(K, V)]
elements: Iterable[(K, V)]
)(implicit w: SparkeyWritable[K, V], koder: Coder[K], voder: Coder[V]): SparkeyUri = {
val writer = new SparkeyWriter(uri, maxMemoryUsage)
val it = elements.iterator
Expand Down Expand Up @@ -237,8 +235,6 @@ package object sparkey extends SparkeyReaderInstances {

logger.info(s"Saving as Sparkey with $numShards shards: $basePath")

val uriCoder = CoderMaterializer.beam(self.context, Coder[JList[SparkeyUri]])

self.transform { collection =>
val shards = collection
.groupBy { case (k, _) => floorMod(w.shardHash(k), numShards).toShort }
Expand All @@ -247,7 +243,7 @@ package object sparkey extends SparkeyReaderInstances {
shard -> writeToSparkey(
uri.sparkeyUriForShard(shard, numShards),
maxMemoryUsage,
xs.asJava
xs
)
}

Expand All @@ -263,17 +259,13 @@ package object sparkey extends SparkeyReaderInstances {
writeToSparkey(
uri.sparkeyUriForShard(shard, numShards),
maxMemoryUsage,
Iterable.empty[(K, V)].asJava
Iterable.empty[(K, V)]
)
)
}
.toSCollection

uris.context
.wrap {
val view = uris.applyInternal(View.asList())
uris.internal.getPipeline.apply(Reify.viewInGlobalWindow(view, uriCoder))
}
uris.reifyAsListInGlobalWindow
.map { _ =>
if (numShards == 1) {
val src = FileSystems
Expand Down Expand Up @@ -495,7 +487,7 @@ package object sparkey extends SparkeyReaderInstances {

override def close(): Unit = sparkeys.values.foreach(_.close())

override def iterator(): util.Iterator[SparkeyReader.Entry] =
override def iterator(): JIterator[SparkeyReader.Entry] =
sparkeys.values.map(_.iterator.asScala).reduce(_ ++ _).asJava
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -698,4 +698,51 @@ class SCollectionTest extends PipelineSpec {
count should containSingleValue(0L)
}
}

it should "reify as List" in {
runWithContext { sc =>
val other = sc.parallelize(Seq(1))
val coll = sc.parallelize(Seq(1, 2)).reifySideInputAsValues(other.asListSideInput)
coll should containInAnyOrder(Seq((1, Seq(1)), (2, Seq(1))))
}
}

it should "reify as Iterable" in {
runWithContext { sc =>
val other = sc.parallelize(Seq(1))
val coll = sc.parallelize(Seq(1, 2)).reifySideInputAsValues(other.asIterableSideInput)
coll should containInAnyOrder(Seq((1, Iterable(1)), (2, Iterable(1))))
}
}

it should "reify as Map" in {
runWithContext { sc =>
val other = sc.parallelize(Seq((1, 1)))
val coll = sc.parallelize(Seq(1, 2)).reifySideInputAsValues(other.asMapSideInput)
coll should containInAnyOrder(Seq((1, Map(1 -> 1)), (2, Map(1 -> 1))))
}
}

it should "reify as Multimap" in {
runWithContext { sc =>
val other = sc.parallelize(Seq((1, 1)))
val coll = sc.parallelize(Seq(1, 2)).reifySideInputAsValues(other.asMultiMapSideInput)
coll should containInAnyOrder(Seq((1, Map(1 -> Iterable(1))), (2, Map(1 -> Iterable(1)))))
}
}

it should "reify in Golbal Window as List" in {
runWithContext { sc =>
val coll = sc.parallelize(Seq(1)).reifyAsListInGlobalWindow
coll should containInAnyOrder(Seq(Seq(1)))
}
}

it should "reify empty in Golbal Window as List" in {
runWithContext { sc =>
val coll = sc.empty[Int].reifyAsListInGlobalWindow
coll should containInAnyOrder(Seq(Seq.empty[Int]))
}
}

}

0 comments on commit 4a20995

Please sign in to comment.