diff --git a/algebird-spark/src/main/scala/com/twitter/algebird/spark/AlgebirdDataset.scala b/algebird-spark/src/main/scala/com/twitter/algebird/spark/AlgebirdDataset.scala new file mode 100644 index 000000000..00b8aa949 --- /dev/null +++ b/algebird-spark/src/main/scala/com/twitter/algebird/spark/AlgebirdDataset.scala @@ -0,0 +1,75 @@ +package com.twitter.algebird + +import org.apache.spark.sql.Dataset +import org.apache.spark.sql.Encoder + +class AlgebirdDataset[T](val ds: Dataset[T]) extends AnyVal { + + def aggregateOption[B: Encoder, C]( + agg: Aggregator[T, B, C] + ): Option[C] = { + val pr = ds.mapPartitions(data => + if (data.isEmpty) Iterator.empty + else { + val b = agg.prepare(data.next) + Iterator(agg.appendAll(b, data)) + } + ) + + val results = pr + .repartition(1) + .mapPartitions(it => agg.semigroup.sumOption(it).toIterator) + .collect + + if (results.isEmpty) None + else Some(agg.present(results.head)) + } + + def aggregate[B: Encoder, C](agg: Aggregator[T, B, C]): C = + (aggregateOption[B, C](agg), agg.semigroup) match { + case (Some(c), _) => c + case (None, m: Monoid[B]) => agg.present(m.zero) + case (None, _) => None.get // no such element + } + + def aggregateByKey[K: Encoder, V1: Encoder, U: Encoder, V2: Encoder]( + agg: Aggregator[V1, U, V2] + )( + implicit ev: T <:< (K, V1), + enc1: Encoder[(K, U)], + enc2: Encoder[(K, V2)] + ): Dataset[(K, V2)] = + keyed + .mapPartitions(it => it.map { case (k, v) => (k, agg.prepare(v)) }) + .groupByKey(_._1) + .reduceGroups((a: (K, U), b: (K, U)) => (a._1, agg.reduce(a._2, b._2))) + .map { case (k, (_, v)) => (k, agg.present(v)) } + + private def keyed[K, V](implicit ev: T <:< (K, V)): Dataset[(K, V)] = + ds.asInstanceOf[Dataset[(K, V)]] + + def sumByKey[K: Encoder, V: Semigroup: Encoder]()( + implicit ev: T <:< (K, V), + enc: Encoder[(K, V)] + ): Dataset[(K, V)] = + keyed + .groupByKey(_._1) + .reduceGroups((a: (K, V), b: (K, V)) => (a._1, implicitly[Semigroup[V]].plus(a._2, b._2))) + .map { case (k, (_, v)) => (k, v) } + + def sumOption(implicit sg: Semigroup[T], enc1: Encoder[T]): Option[T] = { + val partialReduce: Dataset[T] = + ds.mapPartitions(itT => sg.sumOption(itT).toIterator) + + val results = partialReduce + .repartition(1) + .mapPartitions(it => sg.sumOption(it).toIterator) + .collect + + if (results.isEmpty) None + else Some(results.head) + } + + def sum(implicit mon: Monoid[T], enc1: Encoder[T]): T = + sumOption.getOrElse(mon.zero) +} diff --git a/algebird-spark/src/main/scala/com/twitter/algebird/spark/implicits/package.scala b/algebird-spark/src/main/scala/com/twitter/algebird/spark/implicits/package.scala new file mode 100644 index 000000000..bc53d0da8 --- /dev/null +++ b/algebird-spark/src/main/scala/com/twitter/algebird/spark/implicits/package.scala @@ -0,0 +1,29 @@ +package com.twitter.algebird.spark + +package object implicits { + +import com.twitter.algebird.BF + +import com.twitter.algebird.BFZero + +import java.util.PriorityQueue + +import com.twitter.algebird.BloomFilterMonoid + +import org.apache.spark.sql.Encoder + + import scala.reflect.ClassTag + implicit def kryoPriorityQueueEncoder[A](implicit ct: ClassTag[PriorityQueue[A]]): Encoder[PriorityQueue[A]] = + org.apache.spark.sql.Encoders.kryo[PriorityQueue[A]](ct) + + implicit def kryoTuplePriorityQueueEncoder[A, B](implicit ct: ClassTag[(B, PriorityQueue[A])]): Encoder[(B, PriorityQueue[A])] = + org.apache.spark.sql.Encoders.kryo[(B, PriorityQueue[A])](ct) + + implicit def kryoBloomFilterMonoidEncoder[A](implicit ct: ClassTag[BloomFilterMonoid[A]]): Encoder[BloomFilterMonoid[A]] = + org.apache.spark.sql.Encoders.kryo[BloomFilterMonoid[A]](ct) + + implicit def kryoBFZeroEncoder[A](implicit ct: ClassTag[BFZero[A]]): Encoder[BFZero[A]] = + org.apache.spark.sql.Encoders.kryo[BFZero[A]](ct) + + implicit def kryoBFEncoder[A](implicit ct: ClassTag[BF[A]]): Encoder[BF[A]] = org.apache.spark.sql.Encoders.kryo[BF[A]](ct) +} diff --git a/algebird-spark/src/main/scala/com/twitter/algebird/spark/package.scala b/algebird-spark/src/main/scala/com/twitter/algebird/spark/package.scala index 05bb1fb3c..aaa5a7138 100644 --- a/algebird-spark/src/main/scala/com/twitter/algebird/spark/package.scala +++ b/algebird-spark/src/main/scala/com/twitter/algebird/spark/package.scala @@ -10,16 +10,22 @@ import scala.reflect.ClassTag */ package object spark { +import org.apache.spark.sql.Dataset + /** * spark exposes an Aggregator type, so this is here to avoid shadowing */ type AlgebirdAggregator[A, B, C] = Aggregator[A, B, C] val AlgebirdAggregator = Aggregator - implicit class ToAlgebird[T](val rdd: RDD[T]) extends AnyVal { + implicit class ToAlgebirdRDD[T](val rdd: RDD[T]) extends AnyVal { def algebird: AlgebirdRDD[T] = new AlgebirdRDD[T](rdd) } + implicit class ToAlgebirdDataset[T](val ds: Dataset[T]) extends AnyVal { + def algebird: AlgebirdDataset[T] = new AlgebirdDataset[T](ds) + } + def rddMonoid[T: ClassTag](sc: SparkContext): Monoid[RDD[T]] = new Monoid[RDD[T]] { def zero = sc.emptyRDD[T] override def isNonZero(s: RDD[T]) = s.isEmpty @@ -31,4 +37,6 @@ package object spark { } // We should be able to make an Applicative[RDD] except that map needs an implicit ClassTag // which breaks the Applicative signature. I don't see a way around that. + + } diff --git a/algebird-spark/src/test/scala/com/twitter/algebird/spark/AlgebirdDatasetTests.scala b/algebird-spark/src/test/scala/com/twitter/algebird/spark/AlgebirdDatasetTests.scala new file mode 100644 index 000000000..58b82f5ed --- /dev/null +++ b/algebird-spark/src/test/scala/com/twitter/algebird/spark/AlgebirdDatasetTests.scala @@ -0,0 +1,117 @@ +package com.twitter.algebird.spark + +import com.twitter.algebird.{MapAlgebra, Monoid, Semigroup} +import org.scalatest._ +import org.scalatest.funsuite.AnyFunSuite +import com.twitter.algebird.Min +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.Encoder +import org.apache.spark.sql.Dataset +import com.twitter.algebird.BloomFilter +import com.twitter.algebird.BloomFilterAggregator +import com.twitter.algebird.Hash128 + +package test { + // not needed in the algebird package, just testing the API + + object DatasetTest { + def sum[T: Monoid: Encoder](r: Dataset[T]) = r.algebird.sum + } +} + +class AlgebirdDatasetTest extends AnyFunSuite with BeforeAndAfter { + + private var spark: SparkSession = _ + + before { + spark = SparkSession.builder().master("local").getOrCreate() + } + + after { + // try spark.stop() + // catch { + // case t: Throwable => () + // } + } + + // Why does scala.math.Equiv suck so much. + implicit def optEq[V](implicit eq: Equiv[V]): Equiv[Option[V]] = Equiv.fromFunction[Option[V]] { (o1, o2) => + (o1, o2) match { + case (Some(v1), Some(v2)) => eq.equiv(v1, v2) + case (None, None) => true + case _ => false + } + } + + def equiv[V](a: V, b: V)(implicit eq: Equiv[V]): Boolean = eq.equiv(a, b) + def assertEq[V: Equiv](a: V, b: V): Unit = assert(equiv(a, b)) + + def aggregate[T: Encoder, U: Encoder, V: Equiv]( + s: Seq[T], + agg: AlgebirdAggregator[T, U, V] + ): Unit = + assertEq(spark.createDataset(s).algebird.aggregate(agg), agg(s)) + + def aggregateByKey[K: Encoder, T: Encoder, U: Encoder, V: Equiv: Encoder]( + s: Seq[(K, T)], + agg: AlgebirdAggregator[T, U, V] + )(implicit enc1: Encoder[(K, V)], enc2: Encoder[(K, T)], enc3: Encoder[(K, U)]): Unit = { + val resMap = spark.createDataset(s).algebird.aggregateByKey[K, T, U, V](agg).collect.toMap + implicit val sg = agg.semigroup + val algMap = MapAlgebra.sumByKey(s.map { case (k, t) => k -> agg.prepare(t) }).mapValues(agg.present) + s.map(_._1).toSet.foreach { k: K => assertEq(resMap.get(k), algMap.get(k)) } + } + + def sumOption[T: Encoder: Equiv: Semigroup](s: Seq[T]): Unit = + assertEq(spark.createDataset(s).algebird.sumOption, Semigroup.sumOption(s)) + + def sumByKey[K: Encoder, V: Encoder: Semigroup: Equiv]( + s: Seq[(K, V)] + )(implicit enc: Encoder[(K, V)]): Unit = { + val resMap = spark.createDataset(s).algebird.sumByKey[K, V].collect.toMap + val algMap = MapAlgebra.sumByKey(s) + s.map(_._1).toSet.foreach { k: K => assertEq(resMap.get(k), algMap.get(k)) } + } + + import com.twitter.algebird.spark.implicits._ + + implicit val hash = Hash128.intHash + /** + * These tests almost always timeout on Travis. Leaving the + * above to at least check compilation + */ + test("aggregate") { + val sparkImplicits = spark.implicits + import sparkImplicits._ + + aggregate(0 to 1000, AlgebirdAggregator.fromSemigroup[Int]) + aggregate(0 to 1000, AlgebirdAggregator.min[Int]) + aggregate(0 to 1000, AlgebirdAggregator.sortedTake[Int](3)) + aggregate(0 to 1000, BloomFilterAggregator(1000,1000)) + + } + test("sumOption") { + val sparkImplicits = spark.implicits + import sparkImplicits._ + + sumOption(0 to 1000) + sumOption((0 to 1000).map(Min(_))) + sumOption((0 to 1000).map(x => (x, x % 3))) + } + test("aggregateByKey") { + val sparkImplicits = spark.implicits + import sparkImplicits._ + + aggregateByKey((0 to 1000).map(k => (k % 3, k)), AlgebirdAggregator.fromSemigroup[Int]) + aggregateByKey((0 to 1000).map(k => (k % 3, k)), AlgebirdAggregator.min[Int]) + aggregateByKey((0 to 1000).map(k => (k % 3, k)), AlgebirdAggregator.sortedTake[Int](3)) + } + test("sumByKey") { + val sparkImplicits = spark.implicits + import sparkImplicits._ + + sumByKey((0 to 1000).map(k => (k % 3, k))) + sumByKey((0 to 1000).map(k => (k % 3, Option(k)))) + sumByKey((0 to 1000).map(k => (k % 3, Min(k)))) + } +} diff --git a/algebird-spark/src/test/scala/com/twitter/algebird/spark/AlgebirdRDDTests.scala b/algebird-spark/src/test/scala/com/twitter/algebird/spark/AlgebirdRDDTests.scala index b6de4bdf6..878b573a7 100644 --- a/algebird-spark/src/test/scala/com/twitter/algebird/spark/AlgebirdRDDTests.scala +++ b/algebird-spark/src/test/scala/com/twitter/algebird/spark/AlgebirdRDDTests.scala @@ -6,11 +6,13 @@ import org.apache.spark.rdd._ import org.scalatest._ import scala.reflect.ClassTag import org.scalatest.funsuite.AnyFunSuite +import com.twitter.algebird.Min +import org.apache.spark.sql.SparkSession package test { // not needed in the algebird package, just testing the API - import com.twitter.algebird.spark.ToAlgebird - object Test { + import com.twitter.algebird.spark.ToAlgebirdRDD + object RDDTest { def sum[T: Monoid: ClassTag](r: RDD[T]) = r.algebird.sum } } @@ -24,11 +26,7 @@ class AlgebirdRDDTest extends AnyFunSuite with BeforeAndAfter { private var sc: SparkContext = _ before { - // val conf = new SparkConf() - // .setMaster(master) - // .setAppName(appName) - - // sc = new SparkContext(conf) + sc = SparkSession.builder().master("local").getOrCreate().sparkContext } after { @@ -76,24 +74,24 @@ class AlgebirdRDDTest extends AnyFunSuite with BeforeAndAfter { * These tests almost always timeout on Travis. Leaving the * above to at least check compilation */ - // test("aggregate") { - // aggregate(0 to 1000, AlgebirdAggregator.fromSemigroup[Int]) - // aggregate(0 to 1000, AlgebirdAggregator.min[Int]) - // aggregate(0 to 1000, AlgebirdAggregator.sortedTake[Int](3)) - // } - // test("sumOption") { - // sumOption(0 to 1000) - // sumOption((0 to 1000).map(Min(_))) - // sumOption((0 to 1000).map(x => (x, x % 3))) - // } - // test("aggregateByKey") { - // aggregateByKey((0 to 1000).map(k => (k % 3, k)), AlgebirdAggregator.fromSemigroup[Int]) - // aggregateByKey((0 to 1000).map(k => (k % 3, k)), AlgebirdAggregator.min[Int]) - // aggregateByKey((0 to 1000).map(k => (k % 3, k)), AlgebirdAggregator.sortedTake[Int](3)) - // } - // test("sumByKey") { - // sumByKey((0 to 1000).map(k => (k % 3, k))) - // sumByKey((0 to 1000).map(k => (k % 3, Option(k)))) - // sumByKey((0 to 1000).map(k => (k % 3, Min(k)))) - // } + test("aggregate") { + aggregate(0 to 1000, AlgebirdAggregator.fromSemigroup[Int]) + aggregate(0 to 1000, AlgebirdAggregator.min[Int]) + aggregate(0 to 1000, AlgebirdAggregator.sortedTake[Int](3)) + } + test("sumOption") { + sumOption(0 to 1000) + sumOption((0 to 1000).map(Min(_))) + sumOption((0 to 1000).map(x => (x, x % 3))) + } + test("aggregateByKey") { + aggregateByKey((0 to 1000).map(k => (k % 3, k)), AlgebirdAggregator.fromSemigroup[Int]) + aggregateByKey((0 to 1000).map(k => (k % 3, k)), AlgebirdAggregator.min[Int]) + aggregateByKey((0 to 1000).map(k => (k % 3, k)), AlgebirdAggregator.sortedTake[Int](3)) + } + test("sumByKey") { + sumByKey((0 to 1000).map(k => (k % 3, k))) + sumByKey((0 to 1000).map(k => (k % 3, Option(k)))) + sumByKey((0 to 1000).map(k => (k % 3, Min(k)))) + } } diff --git a/build.sbt b/build.sbt index 64baf2688..db436cf98 100644 --- a/build.sbt +++ b/build.sbt @@ -309,7 +309,10 @@ lazy val algebirdBijection = module("bijection") lazy val algebirdSpark = module("spark") .settings( - libraryDependencies += "org.apache.spark" %% "spark-core" % sparkVersion % "provided", + libraryDependencies ++= Seq( + "org.apache.spark" %% "spark-core" % sparkVersion % "provided", + "org.apache.spark" %% "spark-sql" % sparkVersion % "provided" + ), scalacOptions := scalacOptions.value .filterNot( _.contains("inline")