Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/buffer/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ impl<T, const CAP: usize> Buffer<T> for ArrayBuffer<T, CAP> {

#[inline(always)]
fn at(&self, idx: usize) -> *const UnsafeCell<MaybeUninit<T>> {
&self.0[idx % CAP] as * const _
&self.0[idx % CAP] as *const _
}
}

Expand Down
8 changes: 4 additions & 4 deletions src/buffer/dynamic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use super::Buffer;

/// Holds data allocated from the heap at run time
pub struct DynamicBuffer<T> {
items: Box<[UnsafeCell<MaybeUninit<T>>]>
items: Box<[UnsafeCell<MaybeUninit<T>>]>,
}

impl<T> DynamicBuffer<T> {
Expand All @@ -18,7 +18,7 @@ impl<T> DynamicBuffer<T> {
let mut vec = Vec::with_capacity(size);
unsafe { vec.set_len(size) };
Ok(DynamicBuffer {
items: vec.into_boxed_slice()
items: vec.into_boxed_slice(),
})
} else {
Err("Buffer size must be greater than 0")
Expand Down Expand Up @@ -46,7 +46,7 @@ impl<T> Buffer<T> for DynamicBuffer<T> {
/// faster runtime performance due to the use of a mask instead of modulus
/// when computing buffer indexes.
pub struct DynamicBufferP2<T> {
items: Box<[UnsafeCell<MaybeUninit<T>>]>
items: Box<[UnsafeCell<MaybeUninit<T>>]>,
}

impl<T> DynamicBufferP2<T> {
Expand All @@ -61,7 +61,7 @@ impl<T> DynamicBufferP2<T> {
let mut vec = Vec::with_capacity(size);
unsafe { vec.set_len(size) };
vec.into_boxed_slice()
}
},
}),
_ => Err("Buffer size must be a power of two"),
}
Expand Down
2 changes: 1 addition & 1 deletion src/buffer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ impl<T, B: Buffer<T>> Buffer<T> for Box<B> {
(**self).size()
}

fn at(&self, idx: usize) -> * const UnsafeCell<MaybeUninit<T>> {
fn at(&self, idx: usize) -> *const UnsafeCell<MaybeUninit<T>> {
(**self).at(idx)
}
}
10 changes: 10 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,20 +142,30 @@ impl std::error::Error for TryPopError {}
/// The consumer end of the queue allows for sending data. `Producer<T>` is
/// always `Send`, but is only `Sync` for multi-producer (MPSC, MPMC) queues.
pub trait Producer<T> {
/// Check if this channel is closed.
fn is_closed(&self) -> bool;

/// Add value to front of the queue. This method will block if the queue
/// is currently full.
/// If the channel is closed, this function will continue succeeding until
/// the queue becomes full.
fn push(&self, value: T) -> Result<(), PushError<T>>;

/// Attempt to add a value to the front of the queue. If the value was
/// added successfully, `None` will be returned. If unsuccessful, `value`
/// will be returned. An unsuccessful push indicates that the queue was
/// full.
/// If the channel is closed, this function will continue succeeding until
/// the queue becomes full.
fn try_push(&self, value: T) -> Result<(), TryPushError<T>>;
}

/// The consumer end of the queue allows for receiving data. `Consumer<T>` is
/// always `Send`, but is only `Sync` for multi-consumer (SPMC, MPMC) queues.
pub trait Consumer<T> {
/// Check if this channel is closed.
fn is_closed(&self) -> bool;

/// Remove value from the end of the queue. This method will block if the
/// queue is currently empty.
fn pop(&self) -> Result<T, PopError>;
Expand Down
87 changes: 53 additions & 34 deletions src/mpmc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ unsafe impl<T: Send, B: Buffer<T>> Sync for MPMCConsumer<T, B> {}

impl<T, B: Buffer<T>> Clone for MPMCConsumer<T, B> {
fn clone(&self) -> Self {
self.queue.consumers.fetch_add(1, Ordering::Release);
self.queue.consumers.fetch_add(1, Ordering::Relaxed);
MPMCConsumer {
queue: self.queue.clone(),
}
Expand All @@ -50,7 +50,7 @@ unsafe impl<T: Send, B: Buffer<T>> Sync for MPMCProducer<T, B> {}

impl<T, B: Buffer<T>> Clone for MPMCProducer<T, B> {
fn clone(&self) -> Self {
self.queue.producers.fetch_add(1, Ordering::Release);
self.queue.producers.fetch_add(1, Ordering::Relaxed);
MPMCProducer {
queue: self.queue.clone(),
}
Expand Down Expand Up @@ -102,21 +102,24 @@ impl<T, B: Buffer<T>> Drop for MPMCQueue<T, B> {
}

impl<T, B: Buffer<T>> Producer<T> for MPMCProducer<T, B> {
fn is_closed(&self) -> bool {
self.queue.consumers.load(Ordering::Relaxed) == 0
}

fn push(&self, value: T) -> Result<(), PushError<T>> {
let q = &self.queue;

let head = q.head.next.fetch_add(1, Ordering::Relaxed);

loop {
if q.consumers.load(Ordering::Acquire) == 0 {
return Err(PushError::Disconnected(value));
} else if q.tail.curr.load(Ordering::Acquire) + q.buf.size() > head {
if q.tail.curr.load(Ordering::Acquire) + q.buf.size() > head {
break;
} else if q.consumers.load(Ordering::Relaxed) == 0 {
return Err(PushError::Disconnected(value));
}
spin_loop();
}

unsafe { buf_write(&q.buf, head, value) };

while q.head.curr.load(Ordering::Relaxed) < head {
spin_loop();
}
Expand All @@ -128,43 +131,50 @@ impl<T, B: Buffer<T>> Producer<T> for MPMCProducer<T, B> {
let q = &self.queue;
loop {
let head = q.head.curr.load(Ordering::Relaxed);
if q.consumers.load(Ordering::Acquire) == 0 {
return Err(TryPushError::Disconnected(value));
} else if q.tail.curr.load(Ordering::Acquire) + q.buf.size() <= head {
return Err(TryPushError::Full(value));
} else {
let next = head + 1;
if q.head
.next
.compare_exchange_weak(head, next, Ordering::Acquire, Ordering::Acquire)
.is_ok()
{
unsafe { buf_write(&q.buf, head, value) };
q.head.curr.store(next, Ordering::Release);
return Ok(());
}
let head_plus_one = head + 1;

if q.tail.curr.load(Ordering::Acquire) + q.buf.size() <= head {
// buffer is full, check whether it's closed.
// relaxed is fine since Consumer.drop does an acquire/release on .tail
return if q.consumers.load(Ordering::Relaxed) == 0 {
Err(TryPushError::Disconnected(value))
} else {
Err(TryPushError::Full(value))
};
} else if q
.head
.next
.compare_exchange_weak(head, head_plus_one, Ordering::Acquire, Ordering::Acquire)
.is_ok()
{
unsafe { buf_write(&q.buf, head, value) };
q.head.curr.store(head_plus_one, Ordering::Release);
return Ok(());
}
}
}
}

impl<T, B: Buffer<T>> Consumer<T> for MPMCConsumer<T, B> {
fn is_closed(&self) -> bool {
self.queue.producers.load(Ordering::Relaxed) == 0
}

fn pop(&self) -> Result<T, PopError> {
let q = &self.queue;

let tail = q.tail.next.fetch_add(1, Ordering::Relaxed);
let tail_plus_one = tail + 1;

loop {
if tail_plus_one <= q.head.curr.load(Ordering::Acquire) {
if q.head.curr.load(Ordering::Acquire) >= tail_plus_one {
break;
} else if q.producers.load(Ordering::Acquire) == 0 {
} else if q.producers.load(Ordering::Relaxed) == 0 {
return Err(PopError::Disconnected);
}
spin_loop();
}

let v = unsafe { buf_read(&q.buf, tail) };

while q.tail.curr.load(Ordering::Relaxed) < tail {
spin_loop();
}
Expand All @@ -177,12 +187,15 @@ impl<T, B: Buffer<T>> Consumer<T> for MPMCConsumer<T, B> {
loop {
let tail = q.tail.curr.load(Ordering::Relaxed);
let tail_plus_one = tail + 1;
if tail_plus_one > q.head.curr.load(Ordering::Acquire) {
if q.producers.load(Ordering::Acquire) > 0 {
return Err(TryPopError::Empty);

if q.head.curr.load(Ordering::Acquire) < tail_plus_one {
// buffer is empty, check whether it's closed.
// relaxed is fine since Producer.drop does an acquire/release on .head
return if q.producers.load(Ordering::Relaxed) == 0 {
Err(TryPopError::Disconnected)
} else {
return Err(TryPopError::Disconnected);
}
Err(TryPopError::Empty)
};
} else if q
.tail
.next
Expand All @@ -199,13 +212,17 @@ impl<T, B: Buffer<T>> Consumer<T> for MPMCConsumer<T, B> {

impl<T, B: Buffer<T>> Drop for MPMCProducer<T, B> {
fn drop(&mut self) {
self.queue.producers.fetch_sub(1, Ordering::Release);
self.queue.producers.fetch_sub(1, Ordering::Relaxed);
// Acquire/Release .head to ensure other threads see new .closed
self.queue.head.curr.fetch_add(0, Ordering::AcqRel);
}
}

impl<T, B: Buffer<T>> Drop for MPMCConsumer<T, B> {
fn drop(&mut self) {
self.queue.consumers.fetch_sub(1, Ordering::Release);
self.queue.consumers.fetch_sub(1, Ordering::Relaxed);
// Acquire/Release .tail to ensure other threads see new .closed
self.queue.tail.curr.fetch_add(0, Ordering::AcqRel);
}
}

Expand Down Expand Up @@ -308,9 +325,11 @@ mod test {
assert_eq!(c.pop(), Err(PopError::Disconnected));
assert_eq!(c.try_pop(), Err(TryPopError::Disconnected));

let (p, c) = mpmc_queue(DynamicBuffer::new(32).unwrap());
let (p, c) = mpmc_queue(DynamicBuffer::new(2).unwrap());
p.push(1).unwrap();
std::mem::drop(c);
assert!(p.is_closed());
p.push(1).unwrap();
assert_eq!(p.push(2), Err(PushError::Disconnected(2)));
assert_eq!(p.try_push(2), Err(TryPushError::Disconnected(2)));

Expand Down
Loading