Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions tests/tchannels_stop_timeout.nim
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import threading/channels
import std/[os, times, isolation]

type
RecvPayload = tuple[ch: Chan[int], done: Chan[bool]]
SendPayload = tuple[ch: Chan[int], done: Chan[bool]]

proc recvWorker(p: RecvPayload) {.thread.} =
var value: int
let ok = p.ch.recv(value)
discard p.done.send(ok)

proc sendWorker(p: SendPayload) {.thread.} =
let ok = p.ch.send(42)
discard p.done.send(ok)

var destroyedCount = 0

type
DrainProbe = object
id: int

proc `=destroy`(x: var DrainProbe) =
atomicInc(destroyedCount)

block stop_unblocks_recv:
var ch = newChan[int](elements = 1)
var done = newChan[bool](elements = 1)
var thread: Thread[RecvPayload]
createThread(thread, recvWorker, (ch, done))
sleep(50)

doAssert not ch.stopToken()
ch.stop()
doAssert ch.stopToken()

var recvOk = true
doAssert done.recv(recvOk, timeout = initDuration(milliseconds = 500))
doAssert recvOk == false
thread.joinThread()

block stop_unblocks_send:
var ch = newChan[int](elements = 1)
doAssert ch.send(1)

var done = newChan[bool](elements = 1)
var thread: Thread[SendPayload]
createThread(thread, sendWorker, (ch, done))
sleep(50)

ch.stop()

var sendOk = true
doAssert done.recv(sendOk, timeout = initDuration(milliseconds = 500))
doAssert sendOk == false
thread.joinThread()

block recv_timeout:
var ch = newChan[int](elements = 1)
var value: int
doAssert not ch.recv(value, timeout = initDuration(milliseconds = 20))

block send_timeout:
var ch = newChan[int](elements = 1)
doAssert ch.send(1)
doAssert not ch.send(2, timeout = initDuration(milliseconds = 20))

block destroy_drains_pending_items:
let baseline = destroyedCount
block:
var ch = newChan[DrainProbe](elements = 8)
for i in 0..<3:
var iso = isolate(DrainProbe(id: i))
doAssert ch.tryTake(iso)
doAssert destroyedCount - baseline >= 3
160 changes: 128 additions & 32 deletions threading/channels.nim
Original file line number Diff line number Diff line change
Expand Up @@ -100,19 +100,21 @@ runnableExamples("--threads:on --gc:orc"):
when not (defined(gcArc) or defined(gcOrc) or defined(gcAtomicArc) or defined(nimdoc)):
{.error: "This module requires one of --mm:arc / --mm:atomicArc / --mm:orc compilation flags".}

import std/[locks, isolation, atomics]
import std/[locks, isolation, atomics, times]
from std/os import sleep

# Channel
# ------------------------------------------------------------------------------

type
ChannelRaw = ptr ChannelObj
ChannelObj = object
lock: Lock
spaceAvailableCV, dataAvailableCV: Cond
L: Lock
spaceAvailable, dataAvailable: Cond
slots: int ## Number of item slots in the buffer
head: Atomic[int] ## Write/enqueue/send index
tail: Atomic[int] ## Read/dequeue/receive index
stopToken: Atomic[bool]
buffer: ptr UncheckedArray[byte]
atomicCounter: Atomic[int]

Expand All @@ -133,6 +135,12 @@ proc setHead(chan: ChannelRaw, value: int, order: MemoryOrder = moRelaxed) {.inl
proc setAtomicCounter(chan: ChannelRaw, value: int, order: MemoryOrder = moRelaxed) {.inline.} =
chan.atomicCounter.store(value, order)

proc setStopToken(chan: ChannelRaw, value: bool, order: MemoryOrder = moRelaxed) {.inline.} =
chan.stopToken.store(value, order)

proc getStopToken(chan: ChannelRaw, order: MemoryOrder = moRelaxed): bool {.inline.} =
chan.stopToken.load(order)

proc numItems(chan: ChannelRaw): int {.inline.} =
result = chan.getHead() - chan.getTail()
if result < 0:
Expand All @@ -146,6 +154,9 @@ template isFull(chan: ChannelRaw): bool =
template isEmpty(chan: ChannelRaw): bool =
chan.getHead() == chan.getTail()

template isStopped(chan: ChannelRaw): bool =
chan.getStopToken()

# Channels memory ops
# ------------------------------------------------------------------------------

Expand All @@ -155,13 +166,14 @@ proc allocChannel(size, n: int): ChannelRaw =
# To buffer n items, we allocate for n
result.buffer = cast[ptr UncheckedArray[byte]](allocShared(n*size))

initLock(result.lock)
initCond(result.spaceAvailableCV)
initCond(result.dataAvailableCV)
initLock(result.L)
initCond(result.spaceAvailable)
initCond(result.dataAvailable)

result.slots = n
result.setHead(0)
result.setTail(0)
result.setStopToken(false)
result.setAtomicCounter(0)

proc freeChannel(chan: ChannelRaw) =
Expand All @@ -171,33 +183,81 @@ proc freeChannel(chan: ChannelRaw) =
if not chan.buffer.isNil:
deallocShared(chan.buffer)

deinitLock(chan.lock)
deinitCond(chan.spaceAvailableCV)
deinitCond(chan.dataAvailableCV)
deinitLock(chan.L)
deinitCond(chan.spaceAvailable)
deinitCond(chan.dataAvailable)

deallocShared(chan)

# MPMC Channels (Multi-Producer Multi-Consumer)
# ------------------------------------------------------------------------------

proc channelSend(chan: ChannelRaw, data: pointer, size: int, blocking: static bool): bool =
const
TimeoutPollMs = 1

proc stopRaw(chan: ChannelRaw) =
if chan.isNil:
return
acquire(chan.L)
chan.setStopToken(true)
broadcast(chan.spaceAvailable)
broadcast(chan.dataAvailable)
release(chan.L)

proc drainChannel[T](chan: ChannelRaw) =
if chan.isNil:
return
acquire(chan.L)
while not chan.isEmpty():
let readIdx = if chan.getTail() < chan.slots:
chan.getTail()
else:
chan.getTail() - chan.slots
let slot = cast[ptr T](chan.buffer[readIdx * sizeof(T)].addr)
`=destroy`(slot[])

var nextTail = chan.getTail() + 1
if nextTail == 2 * chan.slots:
nextTail = 0
chan.setTail(nextTail)
release(chan.L)

proc channelSend(chan: ChannelRaw, data: pointer, size: int, blocking: static bool,
timeout = default(Duration)): bool =
assert not chan.isNil
assert not data.isNil

when not blocking:
if chan.isFull(): return false
if chan.isFull() or chan.isStopped():
return false

acquire(chan.lock)
acquire(chan.L)

# check for when another thread was faster to fill
when blocking:
let useTimeout = timeout != default(Duration)
let startedAt = if useTimeout: getTime() else: default(Time)
while chan.isFull():
wait(chan.spaceAvailableCV, chan.lock)
if chan.isStopped():
release(chan.L)
return false
if useTimeout:
release(chan.L)
if timeout <= (getTime() - startedAt):
return false
sleep(TimeoutPollMs)
acquire(chan.L)
else:
wait(chan.spaceAvailable, chan.L)
else:
if chan.isFull():
release(chan.lock)
if chan.isFull() or chan.isStopped():
release(chan.L)
return false

if chan.isStopped():
release(chan.L)
return false

assert not chan.isFull()

let writeIdx = if chan.getHead() < chan.slots:
Expand All @@ -206,30 +266,44 @@ proc channelSend(chan: ChannelRaw, data: pointer, size: int, blocking: static bo
chan.getHead() - chan.slots

copyMem(chan.buffer[writeIdx * size].addr, data, size)

atomicInc(chan.head)
if chan.getHead() == 2 * chan.slots:
chan.setHead(0)

signal(chan.dataAvailableCV)
release(chan.lock)
signal(chan.dataAvailable)
release(chan.L)
result = true

proc channelReceive(chan: ChannelRaw, data: pointer, size: int, blocking: static bool): bool =
proc channelReceive(chan: ChannelRaw, data: pointer, size: int, blocking: static bool,
timeout = default(Duration)): bool =
assert not chan.isNil
assert not data.isNil

when not blocking:
if chan.isEmpty(): return false

acquire(chan.lock)
acquire(chan.L)

# check for when another thread was faster to empty
when blocking:
let useTimeout = timeout != default(Duration)
let startedAt = if useTimeout: getTime() else: default(Time)
while chan.isEmpty():
wait(chan.dataAvailableCV, chan.lock)
if chan.isStopped():
release(chan.L)
return false
if useTimeout:
release(chan.L)
if timeout <= (getTime() - startedAt):
return false
sleep(TimeoutPollMs)
acquire(chan.L)
else:
wait(chan.dataAvailable, chan.L)
else:
if chan.isEmpty():
release(chan.lock)
release(chan.L)
return false

assert not chan.isEmpty()
Expand All @@ -245,8 +319,8 @@ proc channelReceive(chan: ChannelRaw, data: pointer, size: int, blocking: static
if chan.getTail() == 2 * chan.slots:
chan.setTail(0)

signal(chan.spaceAvailableCV)
release(chan.lock)
signal(chan.spaceAvailable)
release(chan.L)
result = true

# Public API
Expand All @@ -261,6 +335,7 @@ template frees(c) =
# this `fetchSub` returns current val then subs
# so count == 0 means we're the last
if c.d.atomicCounter.fetchSub(1, moAcquireRelease) == 0:
drainChannel[T](c.d)
freeChannel(c.d)

when defined(nimAllowNonVarDestructor):
Expand Down Expand Up @@ -343,47 +418,68 @@ proc tryRecv*[T](c: Chan[T], dst: var T): bool {.inline.} =
## Returns `false` and does not change `dist` if no message was received.
channelReceive(c.d, dst.addr, sizeof(T), false)

proc send*[T](c: Chan[T], src: sink Isolated[T]) {.inline.} =
proc send*[T](c: Chan[T], src: sink Isolated[T],
timeout = default(Duration)): bool {.inline, discardable.} =
## Sends the message `src` to the channel `c`.
## This blocks the sending thread until `src` was successfully sent.
##
## The memory of `src` is moved, not copied.
##
## If the channel is already full with messages this will block the thread until
## messages from the channel are removed.
##
## If `timeout` is provided and it expires, or if the channel has been stopped,
## this returns `false` and does not send the message.
when defined(gcOrc) and defined(nimSafeOrcSend):
GC_runOrc()
discard channelSend(c.d, src.addr, sizeof(T), true)
wasMoved(src)
result = channelSend(c.d, src.addr, sizeof(T), true, timeout)
if result:
wasMoved(src)

template send*[T](c: Chan[T]; src: T) =
template send*[T](c: Chan[T]; src: T; timeout = default(Duration)): bool =
## Helper template for `send`.
mixin isolate
send(c, isolate(src))
send(c, isolate(src), timeout)

proc recv*[T](c: Chan[T], dst: var T) {.inline.} =
proc recv*[T](c: Chan[T], dst: var T,
timeout = default(Duration)): bool {.inline, discardable.} =
## Receives a message from the channel `c` and fill `dst` with its value.
##
## This blocks the receiving thread until a message was successfully received.
##
## If the channel does not contain any messages this will block the thread until
## a message get sent to the channel.
discard channelReceive(c.d, dst.addr, sizeof(T), true)
##
## If `timeout` is provided and it expires, or if the channel has been stopped
## and no data is available, this returns `false`.
channelReceive(c.d, dst.addr, sizeof(T), true, timeout)

proc recv*[T](c: Chan[T]): T {.inline.} =
## Receives a message from the channel.
## A version of `recv`_ that returns the message.
discard channelReceive(c.d, result.addr, sizeof(T), true)
let ok = channelReceive(c.d, result.addr, sizeof(T), true)
if not ok:
raise newException(ValueError, "channel stopped")

proc recvIso*[T](c: Chan[T]): Isolated[T] {.inline.} =
## Receives a message from the channel.
## A version of `recv`_ that returns the message and isolates it.
discard channelReceive(c.d, result.addr, sizeof(T), true)
let ok = channelReceive(c.d, result.addr, sizeof(T), true)
if not ok:
raise newException(ValueError, "channel stopped")

proc peek*[T](c: Chan[T]): int {.inline.} =
## Returns an estimation of the current number of messages held by the channel.
numItems(c.d)

proc stop*[T](c: Chan[T]) {.inline.} =
## Stops the channel and wakes any waiting send/receive operations.
stopRaw(c.d)

proc stopToken*[T](c: Chan[T]): bool {.inline.} =
## Returns whether this channel has been stopped.
not c.d.isNil and c.d.getStopToken()

proc newChan*[T](elements: Positive = 30): Chan[T] =
## An initialization procedure, necessary for acquiring resources and
## initializing internal state of the channel.
Expand Down
Loading