diff --git a/tests/tchannels_stop_timeout.nim b/tests/tchannels_stop_timeout.nim new file mode 100644 index 0000000..324fc6d --- /dev/null +++ b/tests/tchannels_stop_timeout.nim @@ -0,0 +1,75 @@ +import threading/channels +import std/[os, times, isolation] + +type + RecvPayload = tuple[ch: Chan[int], done: Chan[bool]] + SendPayload = tuple[ch: Chan[int], done: Chan[bool]] + +proc recvWorker(p: RecvPayload) {.thread.} = + var value: int + let ok = p.ch.recv(value) + discard p.done.send(ok) + +proc sendWorker(p: SendPayload) {.thread.} = + let ok = p.ch.send(42) + discard p.done.send(ok) + +var destroyedCount = 0 + +type + DrainProbe = object + id: int + +proc `=destroy`(x: var DrainProbe) = + atomicInc(destroyedCount) + +block stop_unblocks_recv: + var ch = newChan[int](elements = 1) + var done = newChan[bool](elements = 1) + var thread: Thread[RecvPayload] + createThread(thread, recvWorker, (ch, done)) + sleep(50) + + doAssert not ch.stopToken() + ch.stop() + doAssert ch.stopToken() + + var recvOk = true + doAssert done.recv(recvOk, timeout = initDuration(milliseconds = 500)) + doAssert recvOk == false + thread.joinThread() + +block stop_unblocks_send: + var ch = newChan[int](elements = 1) + doAssert ch.send(1) + + var done = newChan[bool](elements = 1) + var thread: Thread[SendPayload] + createThread(thread, sendWorker, (ch, done)) + sleep(50) + + ch.stop() + + var sendOk = true + doAssert done.recv(sendOk, timeout = initDuration(milliseconds = 500)) + doAssert sendOk == false + thread.joinThread() + +block recv_timeout: + var ch = newChan[int](elements = 1) + var value: int + doAssert not ch.recv(value, timeout = initDuration(milliseconds = 20)) + +block send_timeout: + var ch = newChan[int](elements = 1) + doAssert ch.send(1) + doAssert not ch.send(2, timeout = initDuration(milliseconds = 20)) + +block destroy_drains_pending_items: + let baseline = destroyedCount + block: + var ch = newChan[DrainProbe](elements = 8) + for i in 0..<3: + var iso = isolate(DrainProbe(id: i)) + doAssert ch.tryTake(iso) + doAssert destroyedCount - baseline >= 3 diff --git a/threading/channels.nim b/threading/channels.nim index f9d7529..a2e3858 100644 --- a/threading/channels.nim +++ b/threading/channels.nim @@ -100,7 +100,8 @@ runnableExamples("--threads:on --gc:orc"): when not (defined(gcArc) or defined(gcOrc) or defined(gcAtomicArc) or defined(nimdoc)): {.error: "This module requires one of --mm:arc / --mm:atomicArc / --mm:orc compilation flags".} -import std/[locks, isolation, atomics] +import std/[locks, isolation, atomics, times] +from std/os import sleep # Channel # ------------------------------------------------------------------------------ @@ -108,11 +109,12 @@ import std/[locks, isolation, atomics] type ChannelRaw = ptr ChannelObj ChannelObj = object - lock: Lock - spaceAvailableCV, dataAvailableCV: Cond + L: Lock + spaceAvailable, dataAvailable: Cond slots: int ## Number of item slots in the buffer head: Atomic[int] ## Write/enqueue/send index tail: Atomic[int] ## Read/dequeue/receive index + stopToken: Atomic[bool] buffer: ptr UncheckedArray[byte] atomicCounter: Atomic[int] @@ -133,6 +135,12 @@ proc setHead(chan: ChannelRaw, value: int, order: MemoryOrder = moRelaxed) {.inl proc setAtomicCounter(chan: ChannelRaw, value: int, order: MemoryOrder = moRelaxed) {.inline.} = chan.atomicCounter.store(value, order) +proc setStopToken(chan: ChannelRaw, value: bool, order: MemoryOrder = moRelaxed) {.inline.} = + chan.stopToken.store(value, order) + +proc getStopToken(chan: ChannelRaw, order: MemoryOrder = moRelaxed): bool {.inline.} = + chan.stopToken.load(order) + proc numItems(chan: ChannelRaw): int {.inline.} = result = chan.getHead() - chan.getTail() if result < 0: @@ -146,6 +154,9 @@ template isFull(chan: ChannelRaw): bool = template isEmpty(chan: ChannelRaw): bool = chan.getHead() == chan.getTail() +template isStopped(chan: ChannelRaw): bool = + chan.getStopToken() + # Channels memory ops # ------------------------------------------------------------------------------ @@ -155,13 +166,14 @@ proc allocChannel(size, n: int): ChannelRaw = # To buffer n items, we allocate for n result.buffer = cast[ptr UncheckedArray[byte]](allocShared(n*size)) - initLock(result.lock) - initCond(result.spaceAvailableCV) - initCond(result.dataAvailableCV) + initLock(result.L) + initCond(result.spaceAvailable) + initCond(result.dataAvailable) result.slots = n result.setHead(0) result.setTail(0) + result.setStopToken(false) result.setAtomicCounter(0) proc freeChannel(chan: ChannelRaw) = @@ -171,33 +183,81 @@ proc freeChannel(chan: ChannelRaw) = if not chan.buffer.isNil: deallocShared(chan.buffer) - deinitLock(chan.lock) - deinitCond(chan.spaceAvailableCV) - deinitCond(chan.dataAvailableCV) + deinitLock(chan.L) + deinitCond(chan.spaceAvailable) + deinitCond(chan.dataAvailable) deallocShared(chan) # MPMC Channels (Multi-Producer Multi-Consumer) # ------------------------------------------------------------------------------ -proc channelSend(chan: ChannelRaw, data: pointer, size: int, blocking: static bool): bool = +const + TimeoutPollMs = 1 + +proc stopRaw(chan: ChannelRaw) = + if chan.isNil: + return + acquire(chan.L) + chan.setStopToken(true) + broadcast(chan.spaceAvailable) + broadcast(chan.dataAvailable) + release(chan.L) + +proc drainChannel[T](chan: ChannelRaw) = + if chan.isNil: + return + acquire(chan.L) + while not chan.isEmpty(): + let readIdx = if chan.getTail() < chan.slots: + chan.getTail() + else: + chan.getTail() - chan.slots + let slot = cast[ptr T](chan.buffer[readIdx * sizeof(T)].addr) + `=destroy`(slot[]) + + var nextTail = chan.getTail() + 1 + if nextTail == 2 * chan.slots: + nextTail = 0 + chan.setTail(nextTail) + release(chan.L) + +proc channelSend(chan: ChannelRaw, data: pointer, size: int, blocking: static bool, + timeout = default(Duration)): bool = assert not chan.isNil assert not data.isNil when not blocking: - if chan.isFull(): return false + if chan.isFull() or chan.isStopped(): + return false - acquire(chan.lock) + acquire(chan.L) # check for when another thread was faster to fill when blocking: + let useTimeout = timeout != default(Duration) + let startedAt = if useTimeout: getTime() else: default(Time) while chan.isFull(): - wait(chan.spaceAvailableCV, chan.lock) + if chan.isStopped(): + release(chan.L) + return false + if useTimeout: + release(chan.L) + if timeout <= (getTime() - startedAt): + return false + sleep(TimeoutPollMs) + acquire(chan.L) + else: + wait(chan.spaceAvailable, chan.L) else: - if chan.isFull(): - release(chan.lock) + if chan.isFull() or chan.isStopped(): + release(chan.L) return false + if chan.isStopped(): + release(chan.L) + return false + assert not chan.isFull() let writeIdx = if chan.getHead() < chan.slots: @@ -206,30 +266,44 @@ proc channelSend(chan: ChannelRaw, data: pointer, size: int, blocking: static bo chan.getHead() - chan.slots copyMem(chan.buffer[writeIdx * size].addr, data, size) + atomicInc(chan.head) if chan.getHead() == 2 * chan.slots: chan.setHead(0) - signal(chan.dataAvailableCV) - release(chan.lock) + signal(chan.dataAvailable) + release(chan.L) result = true -proc channelReceive(chan: ChannelRaw, data: pointer, size: int, blocking: static bool): bool = +proc channelReceive(chan: ChannelRaw, data: pointer, size: int, blocking: static bool, + timeout = default(Duration)): bool = assert not chan.isNil assert not data.isNil when not blocking: if chan.isEmpty(): return false - acquire(chan.lock) + acquire(chan.L) # check for when another thread was faster to empty when blocking: + let useTimeout = timeout != default(Duration) + let startedAt = if useTimeout: getTime() else: default(Time) while chan.isEmpty(): - wait(chan.dataAvailableCV, chan.lock) + if chan.isStopped(): + release(chan.L) + return false + if useTimeout: + release(chan.L) + if timeout <= (getTime() - startedAt): + return false + sleep(TimeoutPollMs) + acquire(chan.L) + else: + wait(chan.dataAvailable, chan.L) else: if chan.isEmpty(): - release(chan.lock) + release(chan.L) return false assert not chan.isEmpty() @@ -245,8 +319,8 @@ proc channelReceive(chan: ChannelRaw, data: pointer, size: int, blocking: static if chan.getTail() == 2 * chan.slots: chan.setTail(0) - signal(chan.spaceAvailableCV) - release(chan.lock) + signal(chan.spaceAvailable) + release(chan.L) result = true # Public API @@ -261,6 +335,7 @@ template frees(c) = # this `fetchSub` returns current val then subs # so count == 0 means we're the last if c.d.atomicCounter.fetchSub(1, moAcquireRelease) == 0: + drainChannel[T](c.d) freeChannel(c.d) when defined(nimAllowNonVarDestructor): @@ -343,7 +418,8 @@ proc tryRecv*[T](c: Chan[T], dst: var T): bool {.inline.} = ## Returns `false` and does not change `dist` if no message was received. channelReceive(c.d, dst.addr, sizeof(T), false) -proc send*[T](c: Chan[T], src: sink Isolated[T]) {.inline.} = +proc send*[T](c: Chan[T], src: sink Isolated[T], + timeout = default(Duration)): bool {.inline, discardable.} = ## Sends the message `src` to the channel `c`. ## This blocks the sending thread until `src` was successfully sent. ## @@ -351,39 +427,59 @@ proc send*[T](c: Chan[T], src: sink Isolated[T]) {.inline.} = ## ## If the channel is already full with messages this will block the thread until ## messages from the channel are removed. + ## + ## If `timeout` is provided and it expires, or if the channel has been stopped, + ## this returns `false` and does not send the message. when defined(gcOrc) and defined(nimSafeOrcSend): GC_runOrc() - discard channelSend(c.d, src.addr, sizeof(T), true) - wasMoved(src) + result = channelSend(c.d, src.addr, sizeof(T), true, timeout) + if result: + wasMoved(src) -template send*[T](c: Chan[T]; src: T) = +template send*[T](c: Chan[T]; src: T; timeout = default(Duration)): bool = ## Helper template for `send`. mixin isolate - send(c, isolate(src)) + send(c, isolate(src), timeout) -proc recv*[T](c: Chan[T], dst: var T) {.inline.} = +proc recv*[T](c: Chan[T], dst: var T, + timeout = default(Duration)): bool {.inline, discardable.} = ## Receives a message from the channel `c` and fill `dst` with its value. ## ## This blocks the receiving thread until a message was successfully received. ## ## If the channel does not contain any messages this will block the thread until ## a message get sent to the channel. - discard channelReceive(c.d, dst.addr, sizeof(T), true) + ## + ## If `timeout` is provided and it expires, or if the channel has been stopped + ## and no data is available, this returns `false`. + channelReceive(c.d, dst.addr, sizeof(T), true, timeout) proc recv*[T](c: Chan[T]): T {.inline.} = ## Receives a message from the channel. ## A version of `recv`_ that returns the message. - discard channelReceive(c.d, result.addr, sizeof(T), true) + let ok = channelReceive(c.d, result.addr, sizeof(T), true) + if not ok: + raise newException(ValueError, "channel stopped") proc recvIso*[T](c: Chan[T]): Isolated[T] {.inline.} = ## Receives a message from the channel. ## A version of `recv`_ that returns the message and isolates it. - discard channelReceive(c.d, result.addr, sizeof(T), true) + let ok = channelReceive(c.d, result.addr, sizeof(T), true) + if not ok: + raise newException(ValueError, "channel stopped") proc peek*[T](c: Chan[T]): int {.inline.} = ## Returns an estimation of the current number of messages held by the channel. numItems(c.d) +proc stop*[T](c: Chan[T]) {.inline.} = + ## Stops the channel and wakes any waiting send/receive operations. + stopRaw(c.d) + +proc stopToken*[T](c: Chan[T]): bool {.inline.} = + ## Returns whether this channel has been stopped. + not c.d.isNil and c.d.getStopToken() + proc newChan*[T](elements: Positive = 30): Chan[T] = ## An initialization procedure, necessary for acquiring resources and ## initializing internal state of the channel.