diff --git a/kernel-testkit/shared/src/main/scala/cats/effect/kernel/testkit/PureConcGenerators.scala b/kernel-testkit/shared/src/main/scala/cats/effect/kernel/testkit/PureConcGenerators.scala index c6162d20d4..d0b92ffd1e 100644 --- a/kernel-testkit/shared/src/main/scala/cats/effect/kernel/testkit/PureConcGenerators.scala +++ b/kernel-testkit/shared/src/main/scala/cats/effect/kernel/testkit/PureConcGenerators.scala @@ -42,9 +42,8 @@ object PureConcGenerators { override def recursiveGen[B: Arbitrary: Cogen](deeper: GenK[PureConc[E, *]]) = super .recursiveGen[B](deeper) - .filterNot( - _._1 == "racePair" - ) // remove the racePair generator since it reifies nondeterminism, which cannot be law-tested + .filterNot(gen => + gen._1 == "racePair" || gen._1 == "join") // remove generators which reify nondeterminism and cannot be law-tested } implicit def arbitraryPureConc[E: Arbitrary: Cogen, A: Arbitrary: Cogen] diff --git a/kernel-testkit/shared/src/main/scala/cats/effect/kernel/testkit/TimeT.scala b/kernel-testkit/shared/src/main/scala/cats/effect/kernel/testkit/TimeT.scala index a74a18e241..ce63a742d9 100644 --- a/kernel-testkit/shared/src/main/scala/cats/effect/kernel/testkit/TimeT.scala +++ b/kernel-testkit/shared/src/main/scala/cats/effect/kernel/testkit/TimeT.scala @@ -18,7 +18,7 @@ package cats.effect package kernel package testkit -import cats.{~>, Group, Monad, Monoid, Order} +import cats.{~>, Eq, Group, Monad, Monoid, Order} import cats.data.Kleisli import cats.syntax.all._ @@ -90,6 +90,9 @@ private[effect] object TimeT { a.map(_.inverse()) } + implicit def eqTimeT[F[_], A](implicit FA: Eq[F[A]]): Eq[TimeT[F, A]] = + Eq.by(TimeT.run) + implicit def orderTimeT[F[_], A](implicit FA: Order[F[A]]): Order[TimeT[F, A]] = Order.by(TimeT.run) @@ -111,15 +114,69 @@ private[effect] object TimeT { val forkA = time.fork() val forkB = time.fork() - // TODO this doesn't work (yet) because we need to force the "faster" effect to win the race, which right now isn't happening - F.racePair(fa.run(forkA), fb.run(forkB)).map { + def liftOutcome[C](oc: Outcome[F, E, C]): Outcome[TimeT[F, *], E, C] = + oc.mapK(TimeT.liftK[F]) + + F.racePair(fa.run(forkA), fb.run(forkB)).flatMap { case Left((oca, delegate)) => - time.now = forkA.now - Left((oca.mapK(TimeT.liftK[F]), fiberize(forkB, delegate))) + F.onCancel(F.race(delegate.join, F.cede), delegate.cancel).map { + case Left(ocb) if forkB.now < forkA.now => + time.now = forkB.now + Right((completedFiber(forkA, liftOutcome(oca)), liftOutcome(ocb))) + + case _ => + time.now = forkA.now + Left((liftOutcome(oca), fiberize(forkB, delegate))) + } case Right((delegate, ocb)) => - time.now = forkB.now - Right((fiberize(forkA, delegate), ocb.mapK(TimeT.liftK[F]))) + F.onCancel(F.race(delegate.join, F.cede), delegate.cancel).map { + case Left(oca) if forkA.now < forkB.now => + time.now = forkA.now + Left((liftOutcome(oca), completedFiber(forkB, liftOutcome(ocb)))) + + case _ => + time.now = forkB.now + Right((fiberize(forkA, delegate), liftOutcome(ocb))) + } + } + } + + override def race[A, B](fa: TimeT[F, A], fb: TimeT[F, B]): TimeT[F, Either[A, B]] = + uncancelable { poll => + poll(racePair(fa, fb)).flatMap { + case Left((oc, f)) => + oc match { + case Outcome.Succeeded(fa) => f.cancel *> fa.map(Left(_)) + case Outcome.Errored(ea) => f.cancel *> raiseError(ea) + case Outcome.Canceled() => + f.cancel *> poll(f.join) flatMap { + case Outcome.Succeeded(fb) => fb.map(Right(_)) + case Outcome.Errored(eb) => raiseError(eb) + case Outcome.Canceled() => poll(canceled) *> never + } + } + + case Right((f, oc)) => + oc match { + case Outcome.Succeeded(fb) => f.cancel *> fb.map(Right(_)) + case Outcome.Errored(eb) => f.cancel *> raiseError(eb) + case Outcome.Canceled() => + f.cancel *> poll(f.join) flatMap { + case Outcome.Succeeded(fa) => fa.map(Left(_)) + case Outcome.Errored(ea) => raiseError(ea) + case Outcome.Canceled() => poll(canceled) *> never + } + } + } + } + + override def raceOutcome[A, B](fa: TimeT[F, A], fb: TimeT[F, B]) + : TimeT[F, Either[Outcome[TimeT[F, *], E, A], Outcome[TimeT[F, *], E, B]]] = + uncancelable { poll => + poll(racePair(fa, fb)).flatMap { + case Left((oc, f)) => f.cancel.as(Left(oc)) + case Right((f, oc)) => f.cancel.as(Right(oc)) } } @@ -156,5 +213,20 @@ private[effect] object TimeT { } } } + + private[this] def completedFiber[A]( + forked: Time, + outcome: Outcome[TimeT[F, *], E, A]): Fiber[TimeT[F, *], E, A] = + new Fiber[TimeT[F, *], E, A] { + + val cancel = + unit + + val join = + Kleisli { outerTime => + outerTime.now = outerTime.now.max(forked.now) + F.pure(outcome) + } + } } } diff --git a/kernel-testkit/shared/src/main/scala/cats/effect/kernel/testkit/pure.scala b/kernel-testkit/shared/src/main/scala/cats/effect/kernel/testkit/pure.scala index c784639d3f..e538126ae8 100644 --- a/kernel-testkit/shared/src/main/scala/cats/effect/kernel/testkit/pure.scala +++ b/kernel-testkit/shared/src/main/scala/cats/effect/kernel/testkit/pure.scala @@ -41,10 +41,132 @@ object pure { implicit val eq: Eq[MaskId] = Eq.fromUniversalEquals[MaskId] } - final case class FiberCtx[E]( - self: PureFiber[E, _], - masks: List[MaskId] = Nil, - finalizers: List[PureConc[E, Unit]] = Nil) + private[pure] final case class MaskFrame(id: MaskId) + + // None defers finalizer selection until observation; Some scopes an in-fiber request. + private[pure] final case class CancelationSignal[E]( + finalizers: Option[List[PureConc[E, Unit]]]) + + private[pure] final class CancelationListenerId + + private[pure] object CancelationListenerId { + implicit val eq: Eq[CancelationListenerId] = + Eq.fromUniversalEquals[CancelationListenerId] + } + + private[pure] final case class CancelationListener[E]( + id: CancelationListenerId, + action: PureConc[E, Unit]) + + private sealed trait MaskUpdate + + private object MaskUpdate { + case object Removed extends MaskUpdate + case object Shadowed extends MaskUpdate + case object Absent extends MaskUpdate + } + + final class FiberCtx[E] private[pure] ( + val self: PureFiber[E, _], + val masks: List[MaskId], + val finalizers: List[PureConc[E, Unit]], + private[pure] val selfCancelationBoundary: Option[Int], + private[pure] val finalizing: Boolean) + extends Product3[PureFiber[E, _], List[MaskId], List[PureConc[E, Unit]]] + with Serializable { + + def this(self: PureFiber[E, _], masks: List[MaskId], finalizers: List[PureConc[E, Unit]]) = + this(self, masks, finalizers, None, false) + + def this(self: PureFiber[E, _], masks: List[MaskId]) = + this(self, masks, Nil) + + def this(self: PureFiber[E, _]) = + this(self, Nil, Nil) + + def copy( + self: PureFiber[E, _] = this.self, + masks: List[MaskId] = this.masks, + finalizers: List[PureConc[E, Unit]] = this.finalizers): FiberCtx[E] = + FiberCtx.internal(self, masks, finalizers, selfCancelationBoundary, finalizing) + + private[pure] def withSelfCancelationBoundary(boundary: Option[Int]): FiberCtx[E] = + FiberCtx.internal(self, masks, finalizers, boundary, finalizing) + + private[pure] def withFinalizing(value: Boolean): FiberCtx[E] = + FiberCtx.internal(self, masks, finalizers, selfCancelationBoundary, value) + + def _1: PureFiber[E, _] = self + + def _2: List[MaskId] = masks + + def _3: List[PureConc[E, Unit]] = finalizers + + override def canEqual(that: Any): Boolean = + that.isInstanceOf[FiberCtx[_]] + + override def productArity: Int = 3 + + override def productElement(n: Int): Any = + n match { + case 0 => self + case 1 => masks + case 2 => finalizers + case _ => throw new IndexOutOfBoundsException(n.toString) + } + + override def productPrefix: String = "FiberCtx" + + override def equals(that: Any): Boolean = + that match { + case that: FiberCtx[_] => + that.canEqual(this) && + self == that.self && + masks == that.masks && + finalizers == that.finalizers + case _ => + false + } + + override def hashCode: Int = + (self, masks, finalizers).## + + override def toString: String = + s"FiberCtx($self,$masks,$finalizers)" + } + + object FiberCtx { + def $lessinit$greater$default$2[E]: List[MaskId] = + Nil + + def $lessinit$greater$default$3[E]: Nil.type = + Nil + + def apply[E]( + self: PureFiber[E, _], + masks: List[MaskId] = Nil, + finalizers: List[PureConc[E, Unit]] = Nil): FiberCtx[E] = + new FiberCtx(self, masks, finalizers, None, false) + + def tupled[E]: ((PureFiber[E, _], List[MaskId], List[PureConc[E, Unit]])) => FiberCtx[E] = { + case (self, masks, finalizers) => apply(self, masks, finalizers) + } + + def curried[E]: PureFiber[E, _] => List[MaskId] => List[PureConc[E, Unit]] => FiberCtx[E] = + self => masks => finalizers => apply(self, masks, finalizers) + + private[pure] def internal[E]( + self: PureFiber[E, _], + masks: List[MaskId], + finalizers: List[PureConc[E, Unit]], + selfCancelationBoundary: Option[Int], + finalizing: Boolean): FiberCtx[E] = + new FiberCtx(self, masks, finalizers, selfCancelationBoundary, finalizing) + + def unapply[E]( + ctx: FiberCtx[E]): Option[(PureFiber[E, _], List[MaskId], List[PureConc[E, Unit]])] = + Some((ctx.self, ctx.masks, ctx.finalizers)) + } type ResolvedPC[E, A] = ThreadT[IdOC[E, *], A] @@ -69,8 +191,19 @@ object pure { val back = Kleisli.ask[IdOC[E, *], FiberCtx[E]] map { ctx => val checker = ctx .self - .realizeCancelation - .ifM(ApplicativeThread[PureConc[E, *]].done, ().pure[PureConc[E, *]]) + .hasActivePoll + .ifM( + ().pure[PureConc[E, *]], + ctx + .self + .isFinalizing + .ifM( + ().pure[PureConc[E, *]], + ctx + .self + .realizeCancelationWith(ctx) + .ifM(ApplicativeThread[PureConc[E, *]].done[Unit], ().pure[PureConc[E, *]])) + ) checker >> mvarLiftF(ThreadT.liftF(ka)) } @@ -97,45 +230,62 @@ object pure { type Main[X] = MVarR[ResolvedPC[E, *], X] MVar.empty[Main, Outcome[PureConc[E, *], E, A]].flatMap { state0 => - val state = state0[Main] - - val fiber = new PureFiber[E, A](state0) - - val identified = canceled mapF { ta => - val fk = new (FiberR[E, *] ~> IdOC[E, *]) { - def apply[a](ke: FiberR[E, a]) = - ke.run(FiberCtx(fiber)) - } - - ta.mapK(fk) - } - - import Outcome._ - - val body = identified flatMap { a => - state.tryPut(Succeeded(a.pure[PureConc[E, *]])) - } handleErrorWith { e => state.tryPut(Errored(e)) } - - val results = state.read.flatMap { - case Canceled() => (Outcome.Canceled(): IdOC[E, A]).pure[Main] - case Errored(e) => (Outcome.Errored(e): IdOC[E, A]).pure[Main] - - case Succeeded(fa) => - val identifiedCompletion = fa.mapF { ta => - val fk = new (FiberR[E, *] ~> IdOC[E, *]) { - def apply[a](ke: FiberR[E, a]) = - ke.run(FiberCtx(fiber)) + MVar.empty[Main, CancelationSignal[E]] flatMap { canceled0 => + MVar[Main, List[MaskFrame]](Nil) flatMap { masks => + MVar[Main, List[CancelationListener[E]]](Nil) flatMap { cancelationListeners => + MVar[Main, Boolean](false) flatMap { finalizing => + MVar[Main, Int](0) flatMap { activePolls => + val state = state0[Main] + val fiber = + new PureFiber[E, A]( + state0, + canceled0, + masks, + cancelationListeners, + finalizing, + activePolls) + + val identified = canceled mapF { ta => + val fk = new (FiberR[E, *] ~> IdOC[E, *]) { + def apply[a](ke: FiberR[E, a]) = + ke.run(FiberCtx(fiber)) + } + + ta.mapK(fk) + } + + import Outcome._ + + val body = identified flatMap { a => + state.tryPut(Succeeded(a.pure[PureConc[E, *]])) + } handleErrorWith { e => state.tryPut(Errored(e)) } + + val results = state.read.flatMap { + case Canceled() => (Outcome.Canceled(): IdOC[E, A]).pure[Main] + case Errored(e) => (Outcome.Errored(e): IdOC[E, A]).pure[Main] + + case Succeeded(fa) => + val identifiedCompletion = fa.mapF { ta => + val fk = new (FiberR[E, *] ~> IdOC[E, *]) { + def apply[a](ke: FiberR[E, a]) = + ke.run(FiberCtx(fiber)) + } + + ta.mapK(fk) + } + + identifiedCompletion.map(a => + Succeeded[Id, E, A](a): IdOC[E, A]) handleError { e => Errored(e) } + } + + Kleisli.ask[ResolvedPC[E, *], MVar.Universe].map { u => + ApplicativeThread[ResolvedPC[E, *]].start(body.run(u)) >> results.run(u) + } + } } - - ta.mapK(fk) - } - - identifiedCompletion.map(a => Succeeded[Id, E, A](a): IdOC[E, A]) handleError { e => - Errored(e) } + } } - - Kleisli.ask[ResolvedPC[E, *], MVar.Universe].map { u => body.run(u) >> results.run(u) } } } @@ -164,8 +314,9 @@ object pure { case (List(results), _) => results.mapK(optLift) case (_, false) => Outcome.Succeeded(None) - // we could make a writer that only receives one object, but that seems meh. just pretend we deadlocked - case _ => Outcome.Succeeded(None) + // in the case of never and such, we are awaiting the async cancel monitor + // this scenario only arises if the main fiber self cancels + case _ => Outcome.Canceled() } } @@ -197,20 +348,21 @@ object pure { } def canceled: PureConc[E, Unit] = - Thread.annotate("canceled") { - withCtx { ctx => - if (ctx.masks.isEmpty) - uncancelable(_ => ctx.self.cancel >> ctx.finalizers.sequence_ >> Thread.done) - else - ctx.self.cancel - } - } + Thread.annotate("canceled")(withCtx { ctx => + ctx.self.cancelAndRealizeWith(ctx).ifM(Thread.done, unit) + }) def cede: PureConc[E, Unit] = - Thread.cede + withCtx { ctx => + Thread.cede *> + ctx.self.realizeCancelationWith(ctx).ifM(Thread.done, unit) + } def never[A]: PureConc[E, A] = - Thread.annotate("never")(Thread.done[A]) + withCtx[E, A] { ctx => + // we monitor for asynchronous cancelation. if we're masked, this won't cancel and we hang + Thread.annotate("never")(ctx.self.awaitCancelationWith(ctx) *> Thread.done) + } def ref[A](a: A): PureConc[E, Ref[PureConc[E, *], A]] = MVar[PureConc[E, *], A](a).flatMap(mVar => Kleisli.pure(unsafeRef(mVar))) @@ -218,6 +370,9 @@ object pure { def deferred[A]: PureConc[E, Deferred[PureConc[E, *], A]] = MVar.empty[PureConc[E, *], A].flatMap(mVar => Kleisli.pure(unsafeDeferred(mVar))) + private[this] def interruptible[A](ctx: FiberCtx[E], fa: PureConc[E, A]): PureConc[E, A] = + ctx.self.interruptible(ctx)(fa) + private def unsafeRef[A](mVar: MVar[A]): Ref[PureConc[E, *], A] = new Ref[PureConc[E, *], A] { override def get: PureConc[E, A] = mVar.read[PureConc[E, *]] @@ -273,7 +428,8 @@ object pure { private def unsafeDeferred[A](mVar: MVar[A]): Deferred[PureConc[E, *], A] = new Deferred[PureConc[E, *], A] { - override def get: PureConc[E, A] = mVar.read[PureConc[E, *]] + override def get: PureConc[E, A] = + withCtx { ctx => interruptible(ctx, mVar.read[PureConc[E, *]]) } override def complete(a: A): PureConc[E, Boolean] = mVar.tryPut[PureConc[E, *]](a) @@ -283,30 +439,166 @@ object pure { def start[A](fa: PureConc[E, A]): PureConc[E, Fiber[PureConc[E, *], E, A]] = Thread.annotate("start", true) { MVar.empty[PureConc[E, *], Outcome[PureConc[E, *], E, A]].flatMap { state => - val fiber = new PureFiber[E, A](state) + MVar.empty[PureConc[E, *], CancelationSignal[E]] flatMap { canceled => + MVar[PureConc[E, *], List[MaskFrame]](Nil) flatMap { masks => + MVar[PureConc[E, *], List[CancelationListener[E]]](Nil) flatMap { + cancelationListeners => + MVar[PureConc[E, *], Boolean](false) flatMap { finalizing => + MVar[PureConc[E, *], Int](0) flatMap { activePolls => + val fiber = + new PureFiber[E, A]( + state, + canceled, + masks, + cancelationListeners, + finalizing, + activePolls) + + // the tryPut here is interesting: it encodes first-wins semantics on cancelation/completion + val body = guaranteeCase(fa)(state.tryPut[PureConc[E, *]](_).void) + val identified = localCtx(FiberCtx(fiber), body) + Thread.start(identified.attempt.void).as(fiber) + } + } + } + } + } + } + } - // the tryPut here is interesting: it encodes first-wins semantics on cancelation/completion - val body = guaranteeCase(fa)(state.tryPut[PureConc[E, *]](_).void) - val identified = localCtx(FiberCtx(fiber), body) - Thread.start(identified.attempt.void).as(fiber) + override def racePair[A, B](fa: PureConc[E, A], fb: PureConc[E, B]): PureConc[ + E, + Either[ + (Outcome[PureConc[E, *], E, A], Fiber[PureConc[E, *], E, B]), + (Fiber[PureConc[E, *], E, A], Outcome[PureConc[E, *], E, B])]] = + uncancelable { poll => + for { + result <- deferred[ + Either[Outcome[PureConc[E, *], E, A], Outcome[PureConc[E, *], E, B]]] + + fibA <- start(fa) + fibB <- start(fb) + + _ <- start( + fibA + .join + .flatMap(oc => + result + .complete(Left(oc): Either[ + Outcome[PureConc[E, *], E, A], + Outcome[PureConc[E, *], E, B]]) + .void)) + _ <- start( + fibB + .join + .flatMap(oc => + result + .complete(Right(oc): Either[ + Outcome[PureConc[E, *], E, A], + Outcome[PureConc[E, *], E, B]]) + .void)) + + back <- onCancel( + poll(result.get), + for { + canA <- start(fibA.cancel) + canB <- start(fibB.cancel) + + _ <- canA.join + _ <- canB.join + } yield ()) + } yield back match { + case Left(oc) => Left((oc, fibB)) + case Right(oc) => Right((fibA, oc)) } } def uncancelable[A](body: Poll[PureConc[E, *]] => PureConc[E, A]): PureConc[E, A] = Thread.annotate("uncancelable", true) { - val mask = new MaskId + withCtx { ctx => + val mask = new MaskId + val selfCancelationBoundary = + ctx.selfCancelationBoundary.getOrElse(ctx.finalizers.length) + + val self = ctx.self + def updateMasks[B](f: List[MaskFrame] => (List[MaskFrame], B)): PureConc[E, B] = + self.masks.read[PureConc[E, *]].flatMap { ms => + val (updated, b) = f(ms) + self.masks.swap[PureConc[E, *]](updated).as(b) + } + + val addF = updateMasks(ms => (MaskFrame(mask) :: ms, ())) + val removeF = updateMasks { + case MaskFrame(`mask`) :: ms => (ms, MaskUpdate.Removed) + case ms if ms.exists(_.id === mask) => (ms, MaskUpdate.Shadowed) + case ms => (ms, MaskUpdate.Absent) + } - val poll = new Poll[PureConc[E, *]] { - def apply[a](fa: PureConc[E, a]) = - withCtx { ctx => - val ctx2 = ctx.copy(masks = ctx.masks.dropWhile(mask === _)) - localCtx(ctx2, fa.attempt <* ctx.self.realizeCancelation).rethrow + def restore(update: MaskUpdate) = + update match { + case MaskUpdate.Removed => self.exitPoll *> addF + case MaskUpdate.Shadowed | MaskUpdate.Absent => unit } - } - withCtx { ctx => - val ctx2 = ctx.copy(masks = mask :: ctx.masks) - localCtx(ctx2, body(poll)) + val poll = new Poll[PureConc[E, *]] { + def apply[a](fa: PureConc[E, a]) = + withCtx { callCtx => + if (callCtx.self eq self) + removeF.flatMap { update => + val restoreF = restore(update) + val pollCtx = update match { + case MaskUpdate.Removed => + callCtx.withSelfCancelationBoundary( + callCtx + .selfCancelationBoundary + .orElse(Some(selfCancelationBoundary))) + + case MaskUpdate.Shadowed | MaskUpdate.Absent => + callCtx + } + + val enterF = update match { + case MaskUpdate.Removed => self.enterPoll + case MaskUpdate.Shadowed | MaskUpdate.Absent => unit + } + + enterF *> + localCtx( + pollCtx, + onCancel( + self + .realizeCancelationWith(pollCtx) + .ifM( + Thread.done, + fa.attempt.flatMap { result => + self + .realizeCancelationWith(pollCtx) + .ifM( + Thread.done, + restoreF *> result.pure[PureConc[E, *]].rethrow) + }), + restoreF + ) + ) + } + else fa + } + } + + val runBody = + addF *> body(poll).attempt.flatMap { result => + removeF.flatMap { + case MaskUpdate.Removed => + val back = result.pure[PureConc[E, *]].rethrow + + self.realizeCancelationWith(ctx).ifM(Thread.done, back) + + case MaskUpdate.Shadowed | MaskUpdate.Absent => + result.pure[PureConc[E, *]].rethrow + } + } + + onCancel(runBody, removeF.void) } } @@ -315,7 +607,7 @@ object pure { Defer[PureConc[E, *]].defer(pure(new Unique.Token())) def forceR[A, B](fa: PureConc[E, A])(fb: PureConc[E, B]): PureConc[E, B] = - Thread.annotate("forceR")(productR(attempt(fa))(fb)) + Thread.annotate("forceR")(productR(handleError(fa.void)(_ => ()))(fb)) def flatMap[A, B](fa: PureConc[E, A])(f: A => PureConc[E, B]): PureConc[E, B] = M.flatMap(fa)(f) @@ -363,34 +655,249 @@ object pure { } // todo: MVar is not Serializable, release then update here - final class PureFiber[E, A](val state0: MVar[Outcome[PureConc[E, *], E, A]]) + final class PureFiber[E, A]( + val state0: MVar[Outcome[PureConc[E, *], E, A]], + private[this] val canceled0: MVar[CancelationSignal[E]], + private[pure] val masks: MVar[List[MaskFrame]], + private[this] val cancelationListeners: MVar[List[CancelationListener[E]]], + private[this] val finalizing: MVar[Boolean], + private[this] val activePolls: MVar[Int]) extends Fiber[PureConc[E, *], E, A] with Serializable { + def this(state0: MVar[Outcome[PureConc[E, *], E, A]]) = + this(state0, null, null, null, null, null) + private[this] val state = state0[PureConc[E, *]] - private[pure] val canceled: PureConc[E, Boolean] = - state.tryRead.map(_.map(_.fold(true, _ => false, _ => false)).getOrElse(false)) - - private[pure] val realizeCancelation: PureConc[E, Boolean] = - withCtx { ctx => - val checkM = ctx.masks.isEmpty.pure[PureConc[E, *]] - - checkM.ifM( - canceled.ifM( - // if unmasked and canceled, finalize - allocateForPureConc[E].uncancelable(_ => ctx.finalizers.sequence_.as(true)), - // if unmasked but not canceled, ignore - false.pure[PureConc[E, *]] - ), - // if masked, ignore cancelation state but retain until unmasked - false.pure[PureConc[E, *]] + private[pure] val currentMasks: PureConc[E, List[MaskFrame]] = + if (masks eq null) List.empty[MaskFrame].pure[PureConc[E, *]] + else masks.read[PureConc[E, *]] + + private[pure] val hasActivePoll: PureConc[E, Boolean] = + if (activePolls eq null) false.pure[PureConc[E, *]] + else activePolls.read[PureConc[E, *]].map(_ > 0) + + private[pure] val enterPoll: PureConc[E, Unit] = + if (activePolls eq null) ().pure[PureConc[E, *]] + else + activePolls.read[PureConc[E, *]].flatMap { n => + activePolls.swap[PureConc[E, *]](n + 1).void + } + + private[pure] val exitPoll: PureConc[E, Unit] = + if (activePolls eq null) ().pure[PureConc[E, *]] + else + activePolls.read[PureConc[E, *]].flatMap { n => + activePolls.swap[PureConc[E, *]]((n - 1) max 0).void + } + + private[pure] def registerCancelationListener( + notify: PureConc[E, Unit]): PureConc[E, CancelationListenerId] = { + val id = new CancelationListenerId + + if (cancelationListeners eq null) id.pure[PureConc[E, *]] + else + cancelationListeners.read[PureConc[E, *]].flatMap { listeners => + cancelationListeners + .swap[PureConc[E, *]](CancelationListener(id, notify) :: listeners) + .as(id) + } + } + + private[pure] def removeCancelationListener(id: CancelationListenerId): PureConc[E, Unit] = + if (cancelationListeners eq null) ().pure[PureConc[E, *]] + else + cancelationListeners.read[PureConc[E, *]].flatMap { listeners => + cancelationListeners.swap[PureConc[E, *]](listeners.filterNot(_.id === id)).void + } + + private[this] def notifyCancelationListeners: PureConc[E, Unit] = + if (cancelationListeners eq null) ().pure[PureConc[E, *]] + else cancelationListeners.swap[PureConc[E, *]](Nil).flatMap(_.traverse_(_.action)) + + private[pure] def interruptible[B](ctx: FiberCtx[E])(fb: PureConc[E, B]): PureConc[E, B] = { + val Thread = ApplicativeThread[PureConc[E, *]] + + ctx.self.currentMasks.flatMap { + case Nil => + MVar.empty[PureConc[E, *], Option[B]].flatMap { signal => + val notifyCancelation = signal.tryPut[PureConc[E, *]](None).void + + ctx.self.registerCancelationListener(notifyCancelation).flatMap { listener => + val awaitCompletion = + Thread.start(fb.flatMap(b => signal.tryPut[PureConc[E, *]](Some(b)).void)) + + val checkCancelation = + signal.tryRead[PureConc[E, *]].flatMap { + case Some(_) => ().pure[PureConc[E, *]] + case None => + ctx + .self + .realizeCancelationWith(ctx) + .ifM(notifyCancelation, ().pure[PureConc[E, *]]) + } + + awaitCompletion *> + checkCancelation *> + signal.read[PureConc[E, *]].flatMap { + case Some(b) => + ctx.self.removeCancelationListener(listener).as(b) + + case None => + ctx.self.removeCancelationListener(listener) *> + ctx.self.realizeCancelationWith(ctx) *> + Thread.done + } + } + } + + case _ => + fb + } + } + + private[pure] val isFinalizing: PureConc[E, Boolean] = + if (finalizing eq null) false.pure[PureConc[E, *]] + else finalizing.read[PureConc[E, *]] + + private[this] def setFinalizing(value: Boolean): PureConc[E, Unit] = + if (finalizing eq null) ().pure[PureConc[E, *]] + else finalizing.swap[PureConc[E, *]](value).void + + private[this] def finalizeWith( + ctx: FiberCtx[E], + finalizers: List[PureConc[E, Unit]]): PureConc[E, Boolean] = + localCtx( + ctx.copy(finalizers = Nil).withFinalizing(true), + allocateForPureConc[E].uncancelable(_ => finalizers.sequence_) *> + (state0.tryPut[PureConc[E, *]](Outcome.Canceled()).flatMap { + case true => true.pure[PureConc[E, *]] + case false => + state.read.map { + case Outcome.Canceled() => true + case _ => false + } + } <* setFinalizing(false)) + ) + + private[this] def whileFinalizing[B](ctx: FiberCtx[E])(fb: PureConc[E, B]): PureConc[E, B] = + localCtx(ctx.copy(finalizers = Nil).withFinalizing(true), setFinalizing(true) *> fb) + + private[this] def finalizationOutcome: PureConc[E, Boolean] = + state.read.map { + case Outcome.Canceled() => true + case _ => false + } + + private[this] def cancelationFinalizers( + signal: CancelationSignal[E], + ctx: FiberCtx[E]): List[PureConc[E, Unit]] = + signal.finalizers match { + case None => ctx.finalizers + case Some(Nil) if ctx.selfCancelationBoundary.nonEmpty => + cancelationBoundaryFinalizers(ctx) + case Some(finalizers) => finalizers + } + + private[this] def realizeCancelationWithSignal( + ctx: FiberCtx[E], + signal: CancelationSignal[E]): PureConc[E, Boolean] = + if (ctx.finalizing) false.pure[PureConc[E, *]] + else + isFinalizing.ifM( + finalizationOutcome, + ctx + .self + .currentMasks + .map(_.isEmpty) + .ifM( + whileFinalizing(ctx)(finalizeWith(ctx, cancelationFinalizers(signal, ctx))), + false.pure[PureConc[E, *]] + ) + ) + + private[pure] def realizeCancelationWith(ctx: FiberCtx[E]): PureConc[E, Boolean] = + if (ctx.finalizing) false.pure[PureConc[E, *]] + else + isFinalizing.ifM( + finalizationOutcome, + canceled0.tryRead[PureConc[E, *]].flatMap { + case Some(signal) => realizeCancelationWithSignal(ctx, signal) + case None => false.pure[PureConc[E, *]] + } ) + + private[this] def cancelationBoundaryFinalizers(ctx: FiberCtx[E]): List[PureConc[E, Unit]] = + ctx.selfCancelationBoundary match { + case Some(boundary) => ctx.finalizers.take((ctx.finalizers.length - boundary) max 0) + case None => Nil } - val cancel: PureConc[E, Unit] = state.tryPut(Outcome.Canceled()).void + private[pure] def awaitCancelationWith(ctx: FiberCtx[E]): PureConc[E, Boolean] = { + def blocked = + MVar.empty[PureConc[E, *], Unit].flatMap(_.read[PureConc[E, *]]).as(false) + + if (ctx.finalizing) blocked + else + ctx.self.currentMasks.flatMap { + case Nil => + isFinalizing.ifM( + canceled0.tryRead[PureConc[E, *]].map(_.isEmpty), + canceled0.read[PureConc[E, *]].flatMap(realizeCancelationWithSignal(ctx, _))) + + case _ => + isFinalizing.ifM(canceled0.tryRead[PureConc[E, *]].map(_.isEmpty), blocked) + } + } + + private[pure] def cancelAndRealizeWith(ctx: FiberCtx[E]): PureConc[E, Boolean] = + if (ctx.finalizing) false.pure[PureConc[E, *]] + else + isFinalizing.ifM( + ctx.self.currentMasks.map(_.isEmpty), + ctx + .self + .currentMasks + .flatMap { + case Nil => + whileFinalizing(ctx) { + canceled0 + .tryPut[PureConc[E, *]](CancelationSignal[E](Some(ctx.finalizers))) + .flatMap { + case true => + notifyCancelationListeners *> finalizeWith(ctx, ctx.finalizers) + case false => finalizeWith(ctx, ctx.finalizers) + } + } + + case _ => + requestCancelation(Some(Nil)).as(false) + } + ) + + private[this] def requestCancelation( + finalizers: Option[List[PureConc[E, Unit]]]): PureConc[E, Unit] = + canceled0 + .tryPut[PureConc[E, *]](CancelationSignal[E](finalizers)) + .flatMap(inserted => + if (inserted) notifyCancelationListeners else ().pure[PureConc[E, *]]) val join: PureConc[E, Outcome[PureConc[E, *], E, A]] = - state.read + if (canceled0 eq null) state.read + else { + withCtx { ctx => ctx.self.interruptible(ctx)(state.read) } + } + + val cancel: PureConc[E, Unit] = + if (canceled0 eq null) state.tryPut(Outcome.Canceled()).void + else + allocateForPureConc[E].uncancelable { _ => + state.tryRead.flatMap { + case Some(_) => ().pure[PureConc[E, *]] + case None => + requestCancelation(None) *> state.read.void + } + } } } diff --git a/laws/shared/src/test/scala/cats/effect/laws/PureConcSpec.scala b/laws/shared/src/test/scala/cats/effect/laws/PureConcSpec.scala index 5b367f4037..c9c6da67db 100644 --- a/laws/shared/src/test/scala/cats/effect/laws/PureConcSpec.scala +++ b/laws/shared/src/test/scala/cats/effect/laws/PureConcSpec.scala @@ -17,8 +17,9 @@ package cats.effect package laws +import cats.{Eq, Order} import cats.effect.kernel.testkit.{pure, OutcomeGenerators, PureConcGenerators, TimeT} -import cats.effect.kernel.testkit.TimeT._ +import cats.effect.kernel.testkit.TimeT.{eqTimeT => _, orderTimeT => _, _} import cats.effect.kernel.testkit.pure._ import cats.laws.discipline.arbitrary._ @@ -28,13 +29,27 @@ import org.typelevel.discipline.specs2.mutable.Discipline import scala.concurrent.duration._ -class PureConcSpec extends Specification with Discipline with BaseSpec { +private[laws] trait PureConcSpecLowPriorityTimeTInstances { + implicit def orderTimeTPureConcFiniteDuration( + implicit FA: Order[PureConc[Int, FiniteDuration]]) + : Order[TimeT[PureConc[Int, *], FiniteDuration]] = + TimeT.orderTimeT +} + +class PureConcSpec + extends Specification + with Discipline + with PureConcSpecLowPriorityTimeTInstances { import PureConcGenerators._ import OutcomeGenerators._ implicit def exec(fb: TimeT[PureConc[Int, *], Boolean]): Prop = Prop(pure.run(TimeT.run(fb)).fold(false, _ => false, _.getOrElse(false))) + implicit def eqTimeTPureConc[A]( + implicit FA: Eq[PureConc[Int, A]]): Eq[TimeT[PureConc[Int, *], A]] = + TimeT.eqTimeT + "parallel utilities" should { import cats.effect.kernel.{GenConcurrent, Outcome} import cats.effect.kernel.implicits._ @@ -49,6 +64,8 @@ class PureConcSpec extends Specification with Discipline with BaseSpec { } "short-circuit on canceled" in { + pure.run(F.canceled) mustEqual Outcome.Canceled() + pure.run((F.never[Unit], F.canceled).parTupled) mustEqual Outcome.Canceled() pure.run((F.never[Unit], F.canceled).parTupled.start.flatMap(_.join)) mustEqual Outcome .Succeeded(Some(Outcome.canceled[F, Nothing, Unit])) pure.run((F.canceled, F.never[Unit]).parTupled.start.flatMap(_.join)) mustEqual Outcome @@ -75,6 +92,352 @@ class PureConcSpec extends Specification with Discipline with BaseSpec { } } + "core PC state machine" should { + import cats.effect.kernel.{GenConcurrent, GenTemporal, Outcome} + import cats.effect.kernel.implicits._ + import cats.syntax.all._ + + type F[A] = PureConc[Int, A] + val F = GenConcurrent[F] + + "run finalizers when canceling never" in { + val t = for { + c <- F.ref(0) + latch <- F.deferred[Unit] + fib <- F.start((latch.complete(()) *> F.never[Unit]).onCancel(c.update(_ + 1))) + _ <- latch.get + _ <- fib.cancel + v <- c.get + } yield v + + pure.run(t) mustEqual Outcome.Succeeded(Some(1)) + } + + "run finalizers when canceling Deferred#get" in { + val t = for { + c <- F.ref(0) + latch <- F.deferred[Unit] + hang <- F.deferred[Unit] + fib <- F.start((latch.complete(()) *> hang.get).onCancel(c.update(_ + 1))) + _ <- latch.get + _ <- fib.cancel + v <- c.get + } yield v + + pure.run(t) mustEqual Outcome.Succeeded(Some(1)) + } + + "run finalizers when canceling Fiber#join" in { + val t = for { + c <- F.ref(0) + latch <- F.deferred[Unit] + hang <- F.start(F.never[Unit]) + fib <- F.start((latch.complete(()) *> hang.join).onCancel(c.update(_ + 1))) + _ <- latch.get + _ <- fib.cancel + v <- c.get + } yield v + + pure.run(t) mustEqual Outcome.Succeeded(Some(1)) + } + + "hang when canceling uncancelable never" in { + val t = for { + latch <- F.deferred[Unit] + f <- F.start((latch.complete(()) *> F.never[Unit]).uncancelable) + _ <- latch.get + _ <- f.cancel + } yield () + + pure.run(t) mustEqual Outcome.Succeeded(None) + } + + "hang when canceling uncancelable Deferred#get" in { + val t = for { + latch <- F.deferred[Unit] + hang <- F.deferred[Unit] + f <- F.start((latch.complete(()) *> hang.get).uncancelable) + _ <- latch.get + _ <- f.cancel + } yield () + + pure.run(t) mustEqual Outcome.Succeeded(None) + } + + "hang when canceling uncancelable Fiber#join" in { + val t = for { + latch <- F.deferred[Unit] + hang <- F.start(F.never[Unit]) + f <- F.start((latch.complete(()) *> hang.join).uncancelable) + _ <- latch.get + _ <- f.cancel + } yield () + + pure.run(t) mustEqual Outcome.Succeeded(None) + } + + "hang when canceling fiber blocked on cancel finalization" in { + val t = for { + targetStarted <- F.deferred[Unit] + finalizerStarted <- F.deferred[Unit] + target <- F.start( + (targetStarted.complete(()) *> F.never[Unit]) + .onCancel(finalizerStarted.complete(()) *> F.never[Unit])) + _ <- targetStarted.get + canceler <- F.start(target.cancel) + _ <- finalizerStarted.get + _ <- canceler.cancel + } yield () + + pure.run(t) mustEqual Outcome.Succeeded(None) + } + + "run finalizers in order" in { + val t = for { + results <- F.ref[String]("") + f <- F start { + F.canceled.onCancel(results.update(_ + "A")).onCancel(results.update(_ + "B")) + } + _ <- f.join + back <- results.get + } yield back + + pure.run(t) mustEqual Outcome.Succeeded(Some("AB")) + } + + "ignore cancelation of a fiber after racePair has completed" in { + val t = for { + finalized <- F.ref(0) + fiber <- F.start { + F.racePair(F.unit, F.never[Unit]).void.onCancel(finalized.update(_ + 1)) + } + _ <- fiber.join + _ <- fiber.cancel + _ <- F.cede + back <- finalized.get + } yield back + + pure.run(t) mustEqual Outcome.Succeeded(Some(0)) + } + + "ignore cancelation of a fiber after race has completed" in { + val t = for { + finalized <- F.ref(0) + fiber <- F.start { + F.race(F.unit, F.never[Unit]).void.onCancel(finalized.update(_ + 1)) + } + _ <- fiber.join + _ <- fiber.cancel + _ <- F.cede + back <- finalized.get + } yield back + + pure.run(t) mustEqual Outcome.Succeeded(Some(0)) + } + + "correctly interpret uncancelable cancelation followed by suspension" in { + val t = F.uncancelable(_ => F.canceled *> F.never[Unit]) + pure.run(t) mustEqual Outcome.Succeeded(None) + + val forked = pure.run(F.start(t).flatMap(_.joinWith(F.canceled *> F.never[Unit]))) + forked mustEqual Outcome.Succeeded(None) + } + + "ignore poll from another fiber" in { + val t = for { + started <- F.deferred[Unit] + polled <- F.deferred[Unit] + + parent <- F.start { + F.uncancelable { poll => + started.complete(()) *> + F.start(poll(polled.complete(()) *> F.never[Unit])).void *> + polled.get *> + F.never[Unit] + } + } + + _ <- started.get + _ <- polled.get + _ <- parent.cancel + } yield () + + pure.run(t) mustEqual Outcome.Succeeded(None) + } + + "observe external cancelation while blocked inside poll" in { + val t = for { + started <- F.deferred[Unit] + polled <- F.deferred[Unit] + gate <- F.deferred[Unit] + ran <- F.ref(false) + fiber <- F.start { + F.uncancelable { poll => + started.complete(()) *> + poll(polled.complete(()) *> gate.get *> ran.set(true)) + } + } + canceler <- F.start(polled.get *> fiber.cancel) + _ <- started.get + _ <- canceler.join + back <- ran.get + } yield back + + pure.run(t) mustEqual Outcome.Succeeded(Some(false)) + } + + "run finalizers around a self-canceling polled region" in { + val t = for { + finalized <- F.ref(0) + fiber <- F.start { + F.uncancelable { poll => F.onCancel(poll(F.canceled), finalized.update(_ + 1)) } + } + _ <- fiber.join + back <- finalized.get + } yield back + + pure.run(t) mustEqual Outcome.Succeeded(Some(1)) + } + + "run outer finalizers around a self-canceling polled region" in { + val t = for { + finalized <- F.ref(0) + fiber <- F.start { + F.onCancel(F.uncancelable { poll => poll(F.canceled) }, finalized.update(_ + 1)) + } + _ <- fiber.join + back <- finalized.get + } yield back + + pure.run(t) mustEqual Outcome.Succeeded(Some(1)) + } + + "run outer finalizers when a masked self-cancel is observed inside poll" in { + val t = for { + finalized <- F.ref(0) + fiber <- F.start { + F.onCancel( + F.uncancelable { poll => poll(F.uncancelable(_ => F.canceled)) }, + finalized.update(_ + 1)) + } + _ <- fiber.join + back <- finalized.get + } yield back + + pure.run(t) mustEqual Outcome.Succeeded(Some(1)) + } + + "observe pending self-cancel before running a polled region" in { + val t = for { + finalized <- F.ref(0) + ran <- F.ref(false) + fiber <- F.start { + F.uncancelable { poll => + F.canceled *> F.onCancel(poll(ran.set(true)), finalized.update(_ + 1)) + } + } + _ <- fiber.join + fin <- finalized.get + body <- ran.get + } yield (fin, body) + + pure.run(t) mustEqual Outcome.Succeeded(Some((1, false))) + } + + "observe nested self-cancel inside a polled region before continuing" in { + val t = for { + ran <- F.ref(false) + fiber <- F.start { + F.uncancelable { poll => + poll { + F.uncancelable(_ => F.canceled) *> ran.set(true) + } + } + } + _ <- fiber.join + back <- ran.get + } yield back + + pure.run(t) mustEqual Outcome.Succeeded(Some(false)) + } + + "implement locals via Kleisli and FreeT" in { + import cats.{~>, Eval, Id} + import cats.data.Kleisli + import cats.free.FreeT + import cats.syntax.all._ + + type F[A] = FreeT[Id, Kleisli[Eval, Int, *], A] + + def read[A](f: Int => F[A]): F[A] = + FreeT.liftT(Kleisli.ask[Eval, Int]).flatMap(f) + + def withLocal[A](i: Int)(fa: F[A]): F[A] = + fa.mapK(new (Kleisli[Eval, Int, *] ~> Kleisli[Eval, Int, *]) { + def apply[a](kea: Kleisli[Eval, Int, a]) = + Kleisli((_: Int) => kea(i)) + }) + + def run[A](i: Int)(fa: F[A]): A = + fa.runM(fta => Kleisli.liftF(Eval.now(fta))).apply(i).value + + val _ = run(1) { + withLocal(42) { + read { i => + FreeT + .liftT[Id, Kleisli[Eval, Int, *], Unit]( + Kleisli.liftF[Eval, Int, Unit](Eval.later { i mustEqual 42; () })) + .flatMap(_ => + read { i2 => FreeT.liftT(Kleisli.liftF(Eval.later(i2 mustEqual 42))) }) + } + } *> read { i => FreeT.liftT(Kleisli.liftF(Eval.later(i mustEqual 1))) } + } + + ok + } + + "race TimeT values against never" in { + type T[A] = TimeT[F, A] + val T = GenTemporal[T, Int] + + pure.run(TimeT.run(T.race(T.pure(1), T.never[Unit]))) mustEqual Outcome.Succeeded( + Some(Left(1))) + pure.run(TimeT.run(T.race(T.never[Unit], T.pure(1)))) mustEqual Outcome.Succeeded( + Some(Right(1))) + pure.run(TimeT.run(T.race(T.sleep(1.second).as(1), T.never[Unit]))) mustEqual Outcome + .Succeeded(Some(Left(1))) + pure.run(TimeT.run(T.race(T.never[Unit], T.sleep(1.second).as(1)))) mustEqual Outcome + .Succeeded(Some(Right(1))) + pure.run( + TimeT.run( + T.race(T.sleep(2.seconds).as("slow"), T.sleep(1.second).as("fast")))) mustEqual + Outcome.Succeeded(Some(Right("fast"))) + pure.run( + TimeT.run( + T.race(T.sleep(1.second).as("fast"), T.sleep(2.seconds).as("slow")))) mustEqual + Outcome.Succeeded(Some(Left("fast"))) + pure.run(TimeT.run(T.race(T.canceled, T.never[Unit]).void)) mustEqual Outcome.Canceled() + pure.run(TimeT.run(T.race(T.never[Unit], T.canceled).void)) mustEqual Outcome.Canceled() + pure.run( + TimeT.run( + T.race(TimeT.liftF(F.uncancelable(_ => F.canceled.as(1))), T.never[Unit]))) mustEqual + Outcome.Canceled() + pure.run( + TimeT.run( + T.race(T.never[Unit], TimeT.liftF(F.uncancelable(_ => F.canceled.as(1)))))) mustEqual + Outcome.Canceled() + pure.run( + TimeT.run( + T.race(TimeT.liftF(F.start(F.unit).flatMap(_.join).as(1)), T.never[Unit]))) mustEqual + Outcome.Succeeded(Some(Left(1))) + pure.run( + TimeT.run( + T.race(T.never[Unit], TimeT.liftF(F.start(F.unit).flatMap(_.join).as(1))))) mustEqual + Outcome.Succeeded(Some(Right(1))) + } + + } + checkAll( "TimeT[PureConc]", GenTemporalTests[TimeT[PureConc[Int, *], *], Int].temporal[Int, Int, Int](10.millis)