Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@
package cats.effect.benchmarks

import cats.effect.IO
import cats.effect.syntax.all._
import cats.effect.unsafe.implicits.global
import cats.implicits.{catsSyntaxParallelTraverse1, toTraverseOps}

import org.openjdk.jmh.annotations._
import org.openjdk.jmh.infra.Blackhole

import scala.concurrent.duration._

import java.util.concurrent.TimeUnit

/**
Expand Down Expand Up @@ -55,6 +58,24 @@ class ParallelBenchmark {
def parTraverse(): Unit = {
  // run one CPU-burning task per element, with unbounded parallelism
  val work = (1 to size).toList
  work.parTraverse(_ => IO(Blackhole.consumeCPU(cpuTokens))).void.unsafeRunSync()
}

@Benchmark
def parTraverseN(): Unit = {
  // same workload as parTraverse, but parallelism capped at size / 100
  val limit = size / 100
  val tasks = (1 to size).toList
  tasks
    .parTraverseN(limit)(_ => IO(Blackhole.consumeCPU(cpuTokens)))
    .void
    .unsafeRunSync()
}

@Benchmark
def parTraverseNCancel(): Unit = {
  // every task fails after a short sleep, exercising the short-circuit/cancel path
  val boom = new RuntimeException
  val inputs = (1 to size * 100).toList
  val program = inputs.parTraverseN(size / 100) { _ =>
    IO.sleep(100.millis) *> IO.raiseError(boom)
  }

  program.attempt.void.unsafeRunSync()
}

@Benchmark
def traverse(): Unit = {
  // sequential baseline for comparison against the parallel variants
  val items = (1 to size).toList
  items.traverse(_ => IO(Blackhole.consumeCPU(cpuTokens))).void.unsafeRunSync()
}
Expand Down
173 changes: 165 additions & 8 deletions kernel/shared/src/main/scala/cats/effect/kernel/GenConcurrent.scala
Original file line number Diff line number Diff line change
Expand Up @@ -130,29 +130,186 @@ trait GenConcurrent[F[_], E] extends GenSpawn[F, E] {
parTraverseN_(n)(tma)(identity)

/**
 * Like `Parallel.parTraverse`, but limits the degree of parallelism. The semantics of this
 * function are ordered based on the `Traverse`. The first ''n'' actions will be started
 * first, with subsequent actions starting in order as each one completes. Actions which are
 * reached earlier in `traverse` order will be started slightly sooner than later actions, in
 * a non-blocking fashion. Any errors or self-cancelation will immediately abort the sequence.
 * If multiple actions produce errors simultaneously, one of them will be nondeterministically
 * selected for production. If all actions succeed, their results are returned in the same
 * order as their corresponding inputs, regardless of the order in which they executed.
*
* The `f` function is run as part of running the action: in parallel and subject to the
* limit.
*/
def parTraverseN[T[_]: Traverse, A, B](n: Int)(ta: T[A])(f: A => F[B]): F[T[B]] = {
  require(n >= 1, s"Concurrency limit should be at least 1, was: $n")

  implicit val F: GenConcurrent[F, E] = this

  // preempt is completed at most once: Some(e) by the first erroring action, None by the
  // first self-canceling one; once set, no further actions are started
  F.deferred[Option[E]] flatMap { preempt =>
    // every live fiber together with its result slot, so a failure anywhere can
    // cancel (and complete) everything still in flight
    F.ref[Set[(Fiber[F, ?, ?], Deferred[F, Outcome[F, E, B]])]](Set()) flatMap {
      supervision =>
        // has to be done in parallel to avoid head of line issues
        def cancelAll(cause: Option[E]) = supervision.get flatMap { states =>
          val causeOC: Outcome[F, E, B] = cause match {
            case Some(e) => Outcome.Errored(e)
            case None => Outcome.Canceled()
          }

          states.toList parTraverse_ {
            case (fiber, result) =>
              result.complete(causeOC) *> fiber.cancel
          }
        }

        MiniSemaphore[F](n) flatMap { sem =>
          val results = ta traverse { a =>
            // fast path: if some earlier action already failed or canceled,
            // don't start any more work
            preempt.tryGet flatMap {
              case Some(Some(e)) => F.pure(F.raiseError[B](e))
              case Some(None) => F.pure(F.canceled *> F.never[B])

              case None =>
                F.uncancelable { poll =>
                  F.deferred[Outcome[F, E, B]] flatMap { result =>
                    // the laziness is a poor mans defer; this ensures the f gets pushed to the fiber
                    val action = poll(sem.acquire) *> (F.unit >> f(a))
                      .guaranteeCase { oc =>
                        val completion = oc match {
                          case Outcome.Succeeded(_) =>
                            // even on success, defer to a concurrently-recorded failure
                            preempt.tryGet flatMap {
                              case Some(Some(e)) =>
                                result.complete(Outcome.Errored(e))

                              case Some(None) =>
                                result.complete(Outcome.Canceled())

                              case None =>
                                result.complete(oc)
                            }

                          case Outcome.Errored(e) =>
                            // first non-success wins the preempt race; the winner
                            // cancels everything else (in the background)
                            preempt
                              .complete(Some(e))
                              .ifM(
                                result.complete(oc) <* cancelAll(Some(e)).start,
                                false.pure[F])

                          case Outcome.Canceled() =>
                            preempt
                              .complete(None)
                              .ifM(
                                result.complete(oc) <* cancelAll(None).start,
                                false.pure[F])
                        }

                        completion *> sem.release
                      }
                      .void
                      .voidError
                      .start

                    action flatMap { fiber =>
                      supervision.update(_ + ((fiber, result))) map { _ =>
                        result
                          .get
                          .flatMap(_.embed(F.canceled *> F.never))
                          .onCancel(fiber.cancel)
                          .guarantee(supervision.update(_ - ((fiber, result))))
                      }
                    }
                  }
                }
            }
          }

          results.flatMap(_.sequence).onCancel(cancelAll(None))
        }
    }
  }
}

/**
 * Like `Parallel.parTraverse_`, but limits the degree of parallelism. The semantics of this
 * function are ordered based on the `Foldable`. The first ''n'' actions will be started
 * first, with subsequent actions starting in order as each one completes. Actions which are
 * reached earlier in `foldLeftM` order will be started slightly sooner than later actions, in
 * a non-blocking fashion. Any errors or self-cancelation will immediately abort the sequence.
 * If multiple actions produce errors simultaneously, one of them will be nondeterministically
 * selected for production.
*
* The `f` function is run as part of running the action: in parallel and subject to the
* limit.
*/
def parTraverseN_[T[_]: Foldable, A, B](n: Int)(ta: T[A])(f: A => F[B]): F[Unit] = {
  require(n >= 1, s"Concurrency limit should be at least 1, was: $n")

  implicit val F: GenConcurrent[F, E] = this

  // preempt records the first non-success: Some(e) for an error, None for self-cancelation;
  // once set, no further actions are started
  F.deferred[Option[E]] flatMap { preempt =>
    F.ref[List[Fiber[F, ?, ?]]](Nil) flatMap { supervision =>
      MiniSemaphore[F](n) flatMap { sem =>
        val cancelAll = supervision.get.flatMap(_.parTraverse_(_.cancel))

        // doesn't complete until every fiber has been at least *started*
        val startAll = ta traverse_ { a =>
          // first check to see if any of the effects have errored out
          // don't bother starting new things if that happens
          preempt.tryGet flatMap {
            case Some(Some(e)) =>
              F.raiseError[Unit](e)

            case Some(None) =>
              F.canceled

            case None =>
              F.uncancelable { poll =>
                // if the effect produces a non-success, race to kill all the rest
                // the laziness is a poor mans defer; this ensures the f gets pushed to the fiber
                val wrapped = (F.unit >> f(a)) guaranteeCase {
                  case Outcome.Succeeded(_) =>
                    F.unit

                  case Outcome.Errored(e) =>
                    preempt.complete(Some(e)).void

                  case Outcome.Canceled() =>
                    preempt.complete(None).void
                }

                val suppressed = wrapped.void.voidError.guarantee(sem.release)

                poll(sem.acquire) *> suppressed.start flatMap { fiber =>
                  // supervision is handled very differently here: we never remove from the set
                  supervision.update(fiber :: _)
                }
              }
          }
        }

        // we only run this when we know that supervision is full
        val awaitAll = preempt.tryGet flatMap {
          case Some(_) => F.unit
          case None =>
            F.race(preempt.get.void, supervision.get.flatMap(_.traverse_(_.join.void))).void
        }

        // if we hit an error or self-cancelation in any effect, resurface it here
        val resurface = preempt.tryGet flatMap {
          case Some(Some(e)) => F.raiseError[Unit](e)
          case Some(None) => F.canceled
          case None => F.unit
        }

        val work = (startAll *> awaitAll) guaranteeCase {
          case Outcome.Succeeded(_) => F.unit
          case Outcome.Errored(_) | Outcome.Canceled() => preempt.complete(None) *> cancelAll
        }

        // NOTE(review): once `work` completes successfully we are genuinely done, so the
        // resurfacing step is run under a mask. Without this, an external cancelation
        // arriving between `work` and `resurface` would be observed even though every
        // effect had already finished (acknowledged as a correctness issue in review).
        F.uncancelable { poll => poll(work) *> resurface }
      }
    }
  }
}

override def racePair[A, B](fa: F[A], fb: F[B])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,9 @@ import scala.collection.immutable.{Queue => ScalaQueue}
* A cut-down version of semaphore used to implement parTraverseN
*/
private[kernel] abstract class MiniSemaphore[F[_]] extends Serializable {
// acquire a single permit, semantically blocking until one is available
def acquire: F[Unit]
// return a previously-acquired permit to the pool, potentially unblocking a waiter
def release: F[Unit]

/**
 * Sequence an action while holding a permit
 *
 * NOTE(review): `parTraverseN`/`parTraverseN_` now drive `acquire`/`release` directly
 * (so release can happen on the worker fiber); confirm whether `withPermit` still has
 * callers or can be removed.
 */
def withPermit[A](fa: F[A]): F[A]
}

Expand Down
21 changes: 21 additions & 0 deletions tests/jvm/src/test/scala/cats/effect/IOPlatformSpecification.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package cats.effect

import cats.effect.std.Semaphore
import cats.effect.syntax.all._
import cats.effect.unsafe.{
IORuntime,
IORuntimeConfig,
Expand Down Expand Up @@ -812,6 +813,26 @@ trait IOPlatformSpecification extends DetectPlatform { self: BaseSpec with Scala
}
}

"parTraverseN" >> {
  "short-circuit on error" in real {
    // with 100k inputs, only an early abort can finish inside the timeout
    case object TestException extends RuntimeException
    val inputs = List.range(0, 100000)
    val run = inputs.parTraverseN(2)(_ => IO.raiseError(TestException))

    run.attempt.as(ok).timeoutTo(500.millis, IO(false must beTrue))
  }
}

"parTraverseN_" >> {
  "short-circuit on error" in real {
    // with 100k inputs, only an early abort can finish inside the timeout
    case object TestException extends RuntimeException
    val inputs = List.range(0, 100000)
    val run = inputs.parTraverseN_(2)(_ => IO.raiseError(TestException))

    run.attempt.as(ok).timeoutTo(500.millis, IO(false must beTrue))
  }
}

if (javaMajorVersion >= 21)
"block in-place on virtual threads" in real {
val loomExec = classOf[Executors]
Expand Down
Loading
Loading