diff --git a/src/ailego/buffer/parquet_hash_table.cc b/src/ailego/buffer/parquet_hash_table.cc index ab519843e..be8833976 100644 --- a/src/ailego/buffer/parquet_hash_table.cc +++ b/src/ailego/buffer/parquet_hash_table.cc @@ -141,11 +141,18 @@ ParquetBufferContextHandle ParquetBufferPool::acquire_buffer( return ParquetBufferContextHandle(); } std::unique_lock lock(table_mutex_); - if (acquire(buffer_id, table_[buffer_id]).ok()) { - MemoryLimitPool::get_instance().acquire_parquet(table_[buffer_id].size); - arrow = set_block_acquired(buffer_id); + auto [iter, inserted] = table_.try_emplace(buffer_id); + ParquetBufferContext &context = iter->second; + if (!inserted) { + arrow = set_block_acquired(context); + return ParquetBufferContextHandle(buffer_id, arrow); + } + if (acquire(buffer_id, context).ok()) { + MemoryLimitPool::get_instance().acquire_parquet(context.size); + arrow = set_block_acquired(context); return ParquetBufferContextHandle(buffer_id, arrow); } else { + table_.erase(iter); LOG_ERROR("Failed to acquire parquet buffer: %s", buffer_id.to_string().c_str()); return ParquetBufferContextHandle(); @@ -154,25 +161,14 @@ ParquetBufferContextHandle ParquetBufferPool::acquire_buffer( } std::shared_ptr ParquetBufferPool::set_block_acquired( - ParquetBufferID buffer_id) { - ParquetBufferContext &context = table_[buffer_id]; - while (true) { - int current_count = context.ref_count.load(std::memory_order_relaxed); - if (current_count >= 0) { - if (context.ref_count.compare_exchange_weak( - current_count, current_count + 1, std::memory_order_acq_rel, - std::memory_order_acquire)) { - return context.arrow; - } - } else { - if (context.ref_count.compare_exchange_weak(current_count, 1, - std::memory_order_acq_rel, - std::memory_order_acquire)) { - context.load_count.fetch_add(1, std::memory_order_relaxed); - return context.arrow; - } - } + ParquetBufferContext &context) { + int current_count = context.ref_count.load(std::memory_order_relaxed); + if (current_count <= 0) { + context.load_count.fetch_add(1, std::memory_order_relaxed); + current_count = 0; } + context.ref_count.store(current_count + 1, std::memory_order_release); + return context.arrow; } std::shared_ptr ParquetBufferPool::acquire( @@ -181,7 +177,7 @@ std::shared_ptr ParquetBufferPool::acquire( if (iter == table_.end()) { return nullptr; } - ParquetBufferContext &context = table_[buffer_id]; + ParquetBufferContext &context = iter->second; while (true) { int current_count = context.ref_count.load(std::memory_order_acquire); if (current_count < 0) { @@ -196,7 +192,6 @@ std::shared_ptr ParquetBufferPool::acquire( return context.arrow; } } - return nullptr; } std::shared_ptr ParquetBufferPool::acquire_locked( @@ -211,7 +206,7 @@ void ParquetBufferPool::release(ParquetBufferID buffer_id) { if (iter == table_.end()) { return; } - ParquetBufferContext &context = table_[buffer_id]; + ParquetBufferContext &context = iter->second; if (context.ref_count.fetch_sub(1, std::memory_order_release) == 1) { std::atomic_thread_fence(std::memory_order_acquire); BlockEvictionQueue::BlockType block; @@ -227,13 +222,14 @@ void ParquetBufferPool::evict(ParquetBufferID buffer_id) { if (iter == table_.end()) { return; } - ParquetBufferContext &context = table_[buffer_id]; + ParquetBufferContext &context = iter->second; int expected = 0; if (context.ref_count.compare_exchange_strong( expected, std::numeric_limits::min())) { MemoryLimitPool::get_instance().release_parquet(context.size); context.arrow = nullptr; context.arrow_refs.clear(); + table_.erase(iter); } } diff --git a/src/include/zvec/ailego/buffer/parquet_hash_table.h b/src/include/zvec/ailego/buffer/parquet_hash_table.h index 4db1a8f3d..bc99b2c3e 100644 --- a/src/include/zvec/ailego/buffer/parquet_hash_table.h +++ b/src/include/zvec/ailego/buffer/parquet_hash_table.h @@ -105,6 +105,8 @@ class ParquetBufferContextHandle { }; class ParquetBufferPool { + friend class ParquetBufferContextHandle; + public: typedef std::shared_ptr Pointer; @@ -123,21 +125,8 @@ class ParquetBufferPool { using Table = std::unordered_map; - arrow::Status acquire(ParquetBufferID buffer_id, - ParquetBufferContext &context); - ParquetBufferContextHandle acquire_buffer(ParquetBufferID buffer_id); - std::shared_ptr set_block_acquired( - ParquetBufferID buffer_id); - - std::shared_ptr acquire(ParquetBufferID buffer_id); - - std::shared_ptr acquire_locked( - ParquetBufferID buffer_id); - - void release(ParquetBufferID buffer_id); - void evict(ParquetBufferID buffer_id); bool is_dead_node(BlockEvictionQueue::BlockType &block); @@ -155,6 +144,19 @@ class ParquetBufferPool { private: ParquetBufferPool() = default; + std::shared_ptr acquire_locked( + ParquetBufferID buffer_id); + + void release(ParquetBufferID buffer_id); + + arrow::Status acquire(ParquetBufferID buffer_id, + ParquetBufferContext &context); + + std::shared_ptr acquire(ParquetBufferID buffer_id); + + std::shared_ptr set_block_acquired( + ParquetBufferContext &context); + private: Table table_; std::shared_mutex table_mutex_;