diff --git a/benchmarks/src/main/scala/cats/effect/benchmarks/ParallelBenchmark.scala b/benchmarks/src/main/scala/cats/effect/benchmarks/ParallelBenchmark.scala index 0e46ee7fdf..33f093e028 100644 --- a/benchmarks/src/main/scala/cats/effect/benchmarks/ParallelBenchmark.scala +++ b/benchmarks/src/main/scala/cats/effect/benchmarks/ParallelBenchmark.scala @@ -17,12 +17,15 @@ package cats.effect.benchmarks import cats.effect.IO +import cats.effect.syntax.all._ import cats.effect.unsafe.implicits.global import cats.implicits.{catsSyntaxParallelTraverse1, toTraverseOps} import org.openjdk.jmh.annotations._ import org.openjdk.jmh.infra.Blackhole +import scala.concurrent.duration._ + import java.util.concurrent.TimeUnit /** @@ -55,6 +58,24 @@ class ParallelBenchmark { def parTraverse(): Unit = 1.to(size).toList.parTraverse(_ => IO(Blackhole.consumeCPU(cpuTokens))).void.unsafeRunSync() + @Benchmark + def parTraverseN(): Unit = + 1.to(size) + .toList + .parTraverseN(size / 100)(_ => IO(Blackhole.consumeCPU(cpuTokens))) + .void + .unsafeRunSync() + + @Benchmark + def parTraverseNCancel(): Unit = { + val e = new RuntimeException + val test = 1.to(size * 100).toList.parTraverseN(size / 100) { _ => + IO.sleep(100.millis) *> IO.raiseError(e) + } + + test.attempt.void.unsafeRunSync() + } + @Benchmark def traverse(): Unit = 1.to(size).toList.traverse(_ => IO(Blackhole.consumeCPU(cpuTokens))).void.unsafeRunSync() diff --git a/kernel/shared/src/main/scala/cats/effect/kernel/GenConcurrent.scala b/kernel/shared/src/main/scala/cats/effect/kernel/GenConcurrent.scala index 024f640197..6c917e8f5d 100644 --- a/kernel/shared/src/main/scala/cats/effect/kernel/GenConcurrent.scala +++ b/kernel/shared/src/main/scala/cats/effect/kernel/GenConcurrent.scala @@ -130,29 +130,186 @@ trait GenConcurrent[F[_], E] extends GenSpawn[F, E] { parTraverseN_(n)(tma)(identity) /** - * Like `Parallel.parTraverse`, but limits the degree of parallelism. Note that the semantics - * of this operation aim to maximise fairness: when a spot to execute becomes available, every - * task has a chance to claim it, and not only the next `n` tasks in `ta` + * Like `Parallel.parTraverse`, but limits the degree of parallelism. The semantics of this + * function are ordered based on the `Traverse`. The first ''n'' actions will be started + * first, with subsequent actions starting in order as each one completes. Actions which are + * reached earlier in `traverse` order will be started slightly sooner than later actions, in + * a non-blocking fashion. Any errors or self-cancelation will immediately abort the sequence. + * If multiple actions produce errors simultaneously, one of them will be nondeterministically + * selected for production. If all actions succeed, their results are returned in the same + * order as their corresponding inputs, regardless of the order in which they executed. + * + * The `f` function is run as part of running the action: in parallel and subject to the + * limit. */ def parTraverseN[T[_]: Traverse, A, B](n: Int)(ta: T[A])(f: A => F[B]): F[T[B]] = { require(n >= 1, s"Concurrency limit should be at least 1, was: $n") implicit val F: GenConcurrent[F, E] = this - MiniSemaphore[F](n).flatMap { sem => ta.parTraverse { a => sem.withPermit(f(a)) } } + F.deferred[Option[E]] flatMap { preempt => + F.ref[Set[(Fiber[F, ?, ?], Deferred[F, Outcome[F, E, B]])]](Set()) flatMap { + supervision => + // has to be done in parallel to avoid head of line issues + def cancelAll(cause: Option[E]) = supervision.get flatMap { states => + val causeOC: Outcome[F, E, B] = cause match { + case Some(e) => Outcome.Errored(e) + case None => Outcome.Canceled() + } + + states.toList parTraverse_ { + case (fiber, result) => + result.complete(causeOC) *> fiber.cancel + } + } + + MiniSemaphore[F](n) flatMap { sem => + val results = ta traverse { a => + preempt.tryGet flatMap { + case Some(Some(e)) => F.pure(F.raiseError[B](e)) + case Some(None) => F.pure(F.canceled *> F.never[B]) + + case None => + F.uncancelable { poll => + F.deferred[Outcome[F, E, B]] flatMap { result => + // the laziness is a poor mans defer; this ensures the f gets pushed to the fiber + val action = poll(sem.acquire) *> (F.unit >> f(a)) + .guaranteeCase { oc => + val completion = oc match { + case Outcome.Succeeded(_) => + preempt.tryGet flatMap { + case Some(Some(e)) => + result.complete(Outcome.Errored(e)) + + case Some(None) => + result.complete(Outcome.Canceled()) + + case None => + result.complete(oc) + } + + case Outcome.Errored(e) => + preempt + .complete(Some(e)) + .ifM( + result.complete(oc) <* cancelAll(Some(e)).start, + false.pure[F]) + + case Outcome.Canceled() => + preempt + .complete(None) + .ifM( + result.complete(oc) <* cancelAll(None).start, + false.pure[F]) + } + + completion *> sem.release + } + .void + .voidError + .start + + action flatMap { fiber => + supervision.update(_ + ((fiber, result))) map { _ => + result + .get + .flatMap(_.embed(F.canceled *> F.never)) + .onCancel(fiber.cancel) + .guarantee(supervision.update(_ - ((fiber, result)))) + } + } + } + } + } + } + + results.flatMap(_.sequence).onCancel(cancelAll(None)) + } + } + } } /** - * Like `Parallel.parTraverse_`, but limits the degree of parallelism. Note that the semantics - * of this operation aim to maximise fairness: when a spot to execute becomes available, every - * task has a chance to claim it, and not only the next `n` tasks in `ta` + * Like `Parallel.parTraverse_`, but limits the degree of parallelism. The semantics of this + * function are ordered based on the `Foldable`. The first ''n'' actions will be started + * first, with subsequent actions starting in order as each one completes. Actions which are + * reached earlier in `foldLeftM` order will be started slightly sooner than later actions, in + * a non-blocking fashion. Any errors or self-cancelation will immediately abort the sequence. + * If multiple actions produce errors simultaneously, one of them will be nondeterministically + * selected for production. + * + * The `f` function is run as part of running the action: in parallel and subject to the + * limit. */ def parTraverseN_[T[_]: Foldable, A, B](n: Int)(ta: T[A])(f: A => F[B]): F[Unit] = { require(n >= 1, s"Concurrency limit should be at least 1, was: $n") implicit val F: GenConcurrent[F, E] = this - MiniSemaphore[F](n).flatMap { sem => ta.parTraverse_ { a => sem.withPermit(f(a)) } } + F.deferred[Option[E]] flatMap { preempt => + F.ref[List[Fiber[F, ?, ?]]](Nil) flatMap { supervision => + MiniSemaphore[F](n) flatMap { sem => + val cancelAll = supervision.get.flatMap(_.parTraverse_(_.cancel)) + + // doesn't complete until every fiber has been at least *started* + val startAll = ta traverse_ { a => + // first check to see if any of the effects have errored out + // don't bother starting new things if that happens + preempt.tryGet flatMap { + case Some(Some(e)) => + F.raiseError[Unit](e) + + case Some(None) => + F.canceled + + case None => + F.uncancelable { poll => + // if the effect produces a non-success, race to kill all the rest + // the laziness is a poor mans defer; this ensures the f gets pushed to the fiber + val wrapped = (F.unit >> f(a)) guaranteeCase { + case Outcome.Succeeded(_) => + F.unit + + case Outcome.Errored(e) => + preempt.complete(Some(e)).void + + case Outcome.Canceled() => + preempt.complete(None).void + } + + val suppressed = wrapped.void.voidError.guarantee(sem.release) + + poll(sem.acquire) *> suppressed.start flatMap { fiber => + // supervision is handled very differently here: we never remove from the set + supervision.update(fiber :: _) + } + } + } + } + + // we only run this when we know that supervision is full + val awaitAll = preempt.tryGet flatMap { + case Some(_) => F.unit + case None => + F.race(preempt.get.void, supervision.get.flatMap(_.traverse_(_.join.void))).void + } + + // if we hit an error or self-cancelation in any effect, resurface it here + val resurface = preempt.tryGet flatMap { + case Some(Some(e)) => F.raiseError[Unit](e) + case Some(None) => F.canceled + case None => F.unit + } + + val work = (startAll *> awaitAll) guaranteeCase { + case Outcome.Succeeded(_) => F.unit + case Outcome.Errored(_) | Outcome.Canceled() => preempt.complete(None) *> cancelAll + } + + work *> resurface + } + } + } } override def racePair[A, B](fa: F[A], fb: F[B]) diff --git a/kernel/shared/src/main/scala/cats/effect/kernel/MiniSemaphore.scala b/kernel/shared/src/main/scala/cats/effect/kernel/MiniSemaphore.scala index 41368726ab..2eaa00efed 100644 --- a/kernel/shared/src/main/scala/cats/effect/kernel/MiniSemaphore.scala +++ b/kernel/shared/src/main/scala/cats/effect/kernel/MiniSemaphore.scala @@ -27,10 +27,9 @@ import scala.collection.immutable.{Queue => ScalaQueue} * A cut-down version of semaphore used to implement parTraverseN */ private[kernel] abstract class MiniSemaphore[F[_]] extends Serializable { + def acquire: F[Unit] + def release: F[Unit] - /** - * Sequence an action while holding a permit - */ def withPermit[A](fa: F[A]): F[A] } diff --git a/tests/jvm/src/test/scala/cats/effect/IOPlatformSpecification.scala b/tests/jvm/src/test/scala/cats/effect/IOPlatformSpecification.scala index 9dc054dc34..ee99d1ead3 100644 --- a/tests/jvm/src/test/scala/cats/effect/IOPlatformSpecification.scala +++ b/tests/jvm/src/test/scala/cats/effect/IOPlatformSpecification.scala @@ -17,6 +17,7 @@ package cats.effect import cats.effect.std.Semaphore +import cats.effect.syntax.all._ import cats.effect.unsafe.{ IORuntime, IORuntimeConfig, @@ -812,6 +813,26 @@ trait IOPlatformSpecification extends DetectPlatform { self: BaseSpec with Scala } } + "parTraverseN" >> { + "short-circuit on error" in real { + case object TestException extends RuntimeException + val target = 0.until(100000).toList + val test = target.parTraverseN(2)(_ => IO.raiseError(TestException)) + + test.attempt.as(ok).timeoutTo(500.millis, IO(false must beTrue)) + } + } + + "parTraverseN_" >> { + "short-circuit on error" in real { + case object TestException extends RuntimeException + val target = 0.until(100000).toList + val test = target.parTraverseN_(2)(_ => IO.raiseError(TestException)) + + test.attempt.as(ok).timeoutTo(500.millis, IO(false must beTrue)) + } + } + if (javaMajorVersion >= 21) "block in-place on virtual threads" in real { val loomExec = classOf[Executors] diff --git a/tests/shared/src/test/scala/cats/effect/IOSpec.scala b/tests/shared/src/test/scala/cats/effect/IOSpec.scala index b731aa4575..a2eeb2163e 100644 --- a/tests/shared/src/test/scala/cats/effect/IOSpec.scala +++ b/tests/shared/src/test/scala/cats/effect/IOSpec.scala @@ -1614,6 +1614,170 @@ class IOSpec extends BaseSpec with Discipline with IOPlatformSpecification { p must completeAs(true) } + "run finalizers when canceled" in ticked { implicit ticker => + val p = for { + r <- IO.ref(0) + + /* + * The exact series of steps here is: + * + * List(IO.never.onCancel, IO.unit, IO.never.onCancel) + * + * This is significant because we're limiting the parallelism to + * 2, meaning that we will hit a wall after IO.unit. HOWEVER, + * IO.unit completes immediately, so this test not only checks + * cancelation, it also tests that we move onto the third item + * after the second one completes even while the first is blocked. + * In other words, it's testing both cancelation and head of line + * behavior. + */ + f <- List(1, 2, 3) + .parTraverseN(2) { i => + if (i == 2) IO.unit + else IO.never.onCancel(r.update(_ + 1)) + } + .start + + _ <- IO.sleep(100.millis) + _ <- f.cancel + c <- r.get + _ <- IO { c mustEqual 2 } + } yield true + + p must completeAs(true) + } + + "propagate self-cancellation" in ticked { implicit ticker => + List(1, 2, 3, 4) + .parTraverseN(2) { (n: Int) => + if (n == 3) IO.canceled *> IO.never + else IO.pure(n) + } + .void must selfCancel + } + + "run finalizers when a task self-cancels" in ticked { implicit ticker => + val p = for { + r <- IO.ref(0) + fib <- List(1, 2, 3, 4) + .parTraverseN(2) { (n: Int) => + if (n == 3) IO.canceled *> IO.never + else IO.pure(n) + } + .onCancel(r.update(_ + 1)) + .void + .start + _ <- IO.sleep(100.millis) + c <- r.get + _ <- IO { c mustEqual 1 } + oc <- fib.join + } yield oc.isCanceled + + p must completeAs(true) + } + + "not run more than `n` tasks at a time" in real { + def task(counter: Ref[IO, Int], maximum: Ref[IO, Int]): IO[Unit] = { + val acq = counter.updateAndGet(_ + 1).flatMap { count => + maximum.update { max => if (count > max) count else max } + } + IO.asyncForIO.bracket(acq) { _ => IO.sleep(100.millis) }(_ => counter.update(_ - 1)) + } + + for { + maximum <- Ref.of[IO, Int](0) + counter <- Ref.of[IO, Int](0) + nCpu <- IO { Runtime.getRuntime().availableProcessors() } + n = java.lang.Math.max(nCpu, 2) + size = 4 * n + res <- (1 to size).toList.parTraverseN(n) { _ => task(counter, maximum) } + _ <- IO { res.size mustEqual size } + count <- counter.get + _ <- IO { count mustEqual 0 } + max <- maximum.get + _ <- IO { max must beLessThanOrEqualTo(n) } + } yield ok + } + + "run actually in parallel" in real { + val n = 4 + (1 to 2 * n) + .toList + .map { i => IO.sleep(1.second).as(i) } + .parSequenceN(n) + .timeout(3.seconds) + .flatMap { res => IO { res mustEqual (1 to 2 * n).toList } } + } + + "work for empty traverse" in ticked { implicit ticker => + List.empty[Int].parTraverseN(4) { _ => IO.never[String] } must completeAs( + List.empty[String]) + } + + "work for non-empty traverse (ticked)" in ticked { implicit ticker => + List(1).parTraverseN(4) { i => IO.pure(i.toString) } must completeAs(List("1")) + List(1, 2).parTraverseN(3) { i => IO.pure(i.toString) } must completeAs(List("1", "2")) + List(1, 2, 3).parTraverseN(2) { i => IO.pure(i.toString) } must completeAs( + List("1", "2", "3")) + List(1, 2, 3, 4).parTraverseN(1) { i => IO.pure(i.toString) } must completeAs( + List("1", "2", "3", "4")) + } + + "work for non-empty traverse (real)" in real { + for { + _ <- List(1).parTraverseN(4)(i => IO.pure(i.toString)).flatMap { r => + IO(r mustEqual List("1")) + } + _ <- List(1, 2).parTraverseN(3)(i => IO.pure(i.toString)).flatMap { r => + IO(r mustEqual List("1", "2")) + } + _ <- List(1, 2, 3).parTraverseN(2)(i => IO.pure(i.toString)).flatMap { r => + IO(r mustEqual List("1", "2", "3")) + } + _ <- List(1, 2, 3, 4).parTraverseN(1)(i => IO.pure(i.toString)).flatMap { r => + IO(r mustEqual List("1", "2", "3", "4")) + } + _ <- (1 to 10000).toList.parTraverseN(2)(i => IO.pure(i.toString)).flatMap { r => + IO(r mustEqual (1 to 10000).map(_.toString).toList) + } + } yield ok + } + + "be null-safe" in real { + for { + r1 <- List[String]("a", "b", null, "d", null).parTraverseN(2) { + case "a" => IO.pure(null) + case "b" => IO.pure("x") + case "d" => IO.pure(null) + case null => IO.pure("z") + } + _ <- IO { r1 mustEqual List(null, "x", "z", null, "z") } + } yield ok + } + + "run finalizers in parallel" in ticked { implicit ticker => + // this test also tests to ensure that we get the errored results rather than cancels + // note that the first two effects will have a Canceled outcome, while the third is Errored + // if we just go by first wins in sequence, then Canceled is the (incorrect) result + // first wins *in time* is the expected semantic here + val test = for { + latch1 <- IO.deferred[Unit] + latch2 <- IO.deferred[Unit] + + _ <- List(1, 2, 3).parTraverseN(3) { + case 1 => + IO.never.onCancel(latch1.complete(()) *> latch2.get) + + case 2 => + IO.never.onCancel(latch2.complete(()) *> latch1.get) + + case 3 => + IO.sleep(10.millis) *> IO.raiseError(new RuntimeException) + } + } yield () + + test.attempt.void must completeAs(()) + } } "parTraverseN_" should { @@ -1642,6 +1806,209 @@ class IOSpec extends BaseSpec with Discipline with IOPlatformSpecification { p must completeAs(true) } + "run finalizers when canceled" in ticked { implicit ticker => + val p = for { + r <- IO.ref(0) + + /* + * The exact series of steps here is: + * + * List(IO.never.onCancel, IO.unit, IO.never.onCancel) + * + * This is significant because we're limiting the parallelism to + * 2, meaning that we will hit a wall after IO.unit. HOWEVER, + * IO.unit completes immediately, so this test not only checks + * cancelation, it also tests that we move onto the third item + * after the second one completes even while the first is blocked. + * In other words, it's testing both cancelation and head of line + * behavior. + */ + f <- List(1, 2, 3) + .parTraverseN_(2) { i => + if (i == 2) IO.unit + else IO.never.onCancel(r.update(_ + 1)) + } + .start + + _ <- IO.sleep(100.millis) + _ <- f.cancel + c <- r.get + _ <- IO { c mustEqual 2 } + } yield true + + p must completeAs(true) + } + + "propagate self-cancellation" in ticked { implicit ticker => + List(1, 2, 3, 4) + .parTraverseN_(2) { (n: Int) => + if (n == 3) IO.canceled *> IO.never + else IO.pure(n) + } + .void must selfCancel + } + + "run finalizers when a task self-cancels" in ticked { implicit ticker => + val p = for { + r <- IO.ref(0) + fib <- List(1, 2, 3, 4) + .parTraverseN_(2) { (n: Int) => + if (n == 3) IO.canceled *> IO.never + else IO.pure(n) + } + .onCancel(r.update(_ + 1)) + .void + .start + _ <- IO.sleep(100.millis) + c <- r.get + _ <- IO { c mustEqual 1 } + oc <- fib.join + } yield oc.isCanceled + + p must completeAs(true) + } + + "run finalizers when a task self-cancels after everything started" in ticked { + implicit ticker => + val p = for { + ds <- IO.deferred[Unit].replicateA(3) + f <- ds + .parTraverseN_(4) { d => + (if (d eq ds(1)) IO.sleep(100.millis) *> IO.canceled + else IO.never).onCancel(d.complete(()).void) + } + .start + _ <- IO.sleep(50.millis) // all 3 fibers start (limit is 4) + oc <- f.join // after another 50ms, one of them self-cancels + _ <- IO { oc.isCanceled mustEqual true } + _ <- ds.traverse_( + _.get + ) // every finalizer must've ran, so every Deferred must be completed + } yield true + + p must completeAs(true) + } + + "run finalizers when a task errors after everything started" in ticked { + implicit ticker => + val p = for { + ds <- IO.deferred[Unit].replicateA(3) + f <- ds + .parTraverseN_(4) { d => + (if (d eq ds(1)) IO.sleep(100.millis) *> IO.raiseError(new Exception) + else IO.never).guarantee(d.complete(()).void) + } + .start + _ <- IO.sleep(50.millis) // all 3 fibers start (limit is 4) + oc <- f.join // after another 50ms, one of them errors + _ <- IO { oc.isError mustEqual true } + _ <- ds.traverse_( + _.get + ) // every finalizer must've ran, so every Deferred must be completed + } yield true + + p must completeAs(true) + } + + "not run more than `n` tasks at a time" in real { + def task(counter: Ref[IO, Int], maximum: Ref[IO, Int]): IO[Unit] = { + val acq = counter.updateAndGet(_ + 1).flatMap { count => + maximum.update { max => if (count > max) count else max } + } + IO.asyncForIO.bracket(acq) { _ => IO.sleep(100.millis) }(_ => counter.update(_ - 1)) + } + + for { + maximum <- Ref.of[IO, Int](0) + counter <- Ref.of[IO, Int](0) + nCpu <- IO { Runtime.getRuntime().availableProcessors() } + n = java.lang.Math.max(nCpu, 2) + size = 4 * n + _ <- (1 to size).toList.parTraverseN_(n) { _ => task(counter, maximum) } + count <- counter.get + _ <- IO { count mustEqual 0 } + max <- maximum.get + _ <- IO { max must beLessThanOrEqualTo(n) } + } yield ok + } + + "run actually in parallel" in real { + val n = 4 + (1 to 2 * n) + .toList + .map(_ => IO.sleep(1.second)) + .parSequenceN_(n) + .as(true) + .timeoutTo(3.seconds, IO.pure(false)) + .flatMap(res => IO { res must beTrue }) + } + + "work for empty traverse" in ticked { implicit ticker => + List.empty[Int].parTraverseN_(4) { _ => IO.never[String] } must completeAs(()) + } + + "work for non-empty traverse (ticked)" in ticked { implicit ticker => + List(1).parTraverseN_(4) { i => IO.pure(i.toString) } must completeAs(()) + List(1, 2).parTraverseN_(3) { i => IO.pure(i.toString) } must completeAs(()) + List(1, 2, 3).parTraverseN_(2) { i => IO.pure(i.toString) } must completeAs(()) + List(1, 2, 3, 4).parTraverseN_(1) { i => IO.pure(i.toString) } must completeAs(()) + } + + "work for non-empty traverse (real)" in real { + for { + _ <- List(1).parTraverseN_(4)(i => IO.pure(i.toString)).flatMap { r => + IO(r.mustEqual(())) + } + _ <- List(1, 2).parTraverseN_(3)(i => IO.pure(i.toString)).flatMap { r => + IO(r.mustEqual(())) + } + _ <- List(1, 2, 3).parTraverseN_(2)(i => IO.pure(i.toString)).flatMap { r => + IO(r.mustEqual(())) + } + _ <- List(1, 2, 3, 4).parTraverseN_(1)(i => IO.pure(i.toString)).flatMap { r => + IO(r.mustEqual(())) + } + _ <- (1 to 10000).toList.parTraverseN_(2)(i => IO.pure(i.toString)).flatMap { r => + IO(r.mustEqual(())) + } + } yield ok + } + + "be null-safe" in real { + for { + r1 <- List[String]("a", "b", null, "d", null).parTraverseN_(2) { + case "a" => IO.pure(null) + case "b" => IO.pure("x") + case "d" => IO.pure(null) + case null => IO.pure("z") + } + _ <- IO { r1 mustEqual (()) } // just trying to make sure we don't crash + } yield ok + } + + "run finalizers in parallel" in ticked { implicit ticker => + // this test also tests to ensure that we get the errored results rather than cancels + // note that the first two effects will have a Canceled outcome, while the third is Errored + // if we just go by first wins in sequence, then Canceled is the (incorrect) result + // first wins *in time* is the expected semantic here + val test = for { + latch1 <- IO.deferred[Unit] + latch2 <- IO.deferred[Unit] + + _ <- List(1, 2, 3).parTraverseN_(3) { + case 1 => + IO.never.onCancel(latch1.complete(()) *> latch2.get) + + case 2 => + IO.never.onCancel(latch2.complete(()) *> latch1.get) + + case 3 => + IO.sleep(10.millis) *> IO.raiseError(new RuntimeException) + } + } yield () + + test.attempt.void must completeAs(()) + } } "parallel" should {