Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 94 additions & 113 deletions src/ensemble_scheduler/ensemble_scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

#include "ensemble_scheduler.h"

#include <chrono>
#include <condition_variable>
#include <mutex>

Expand Down Expand Up @@ -152,82 +153,6 @@ class RequestTracker {
triton::common::ThreadPool* const callback_pool_;
};

// Limits concurrent inflight requests for a single ensemble step.
// Tracks inflight requests count and blocks producers when limit is reached.
class StepInflightRequestLimiter {
public:
explicit StepInflightRequestLimiter(const size_t max_inflight)
: inflight_count_(0), max_inflight_(max_inflight)
{
}

// Wait until capacity is available or request is cancelled.
// No-op if limit not configured (max_inflight_ == 0).
void WaitForCapacity(
RequestTracker* request_tracker, const size_t step_idx,
const std::string& ensemble_name)
{
// No limit configured, no blocking
if (max_inflight_ == 0) {
return;
}

std::unique_lock<std::mutex> lk(mutex_);
auto timeout = std::chrono::seconds(kMutexTimeoutSeconds);

auto is_request_cancelled = [&]() {
auto& req = request_tracker->Request();
return (req == nullptr) || req->IsCancelled();
};

bool capacity_available = cv_.wait_for(lk, timeout, [&] {
return is_request_cancelled() || (inflight_count_ < max_inflight_);
});

// Log error if timeout occurred (not cancellation), but proceed anyway
// to avoid deadlock. Caller always continues after this call.
if (!capacity_available && !is_request_cancelled()) {
LOG_ERROR << "[Internal Error] Ensemble '" << ensemble_name
<< "' unable to schedule step " << step_idx
<< " (inflight: " << inflight_count_
<< " >= limit: " << max_inflight_ << ") for "
<< kMutexTimeoutSeconds
<< " seconds. Proceeding to avoid deadlock.";
}
}

// Increment inflight count after successfully scheduling a request.
// No-op if limit not configured (max_inflight_ == 0).
void IncrementInflightCount()
{
// No limit configured, no tracking needed
if (max_inflight_ == 0) {
return;
}
std::lock_guard<std::mutex> lk(mutex_);
inflight_count_++;
}

// Decrement inflight count when a request completes, and notify waiting
// producers. No-op if limit not configured (max_inflight_ == 0).
void DecrementInflightCount()
{
// No limit configured, no tracking needed
if (max_inflight_ == 0) {
return;
}
std::lock_guard<std::mutex> lk(mutex_);
inflight_count_--;
cv_.notify_one();
}

private:
size_t inflight_count_;
const size_t max_inflight_;
std::mutex mutex_;
std::condition_variable cv_;
};

// Step is used as 'userp' and keeps ensemble context alive
// until no more internal requests are inflight.
// Step contains metadata, and status for the
Expand Down Expand Up @@ -448,11 +373,6 @@ class EnsembleContext {

size_t inflight_step_counter_;

// Inflight request limiters for each ensemble step.
// Only allocated when max_inflight_requests_ > 0.
std::vector<std::unique_ptr<StepInflightRequestLimiter>>
step_inflight_request_limiters_;

// pointer that either points to 'pruned_tensor_to_step_' or to
// 'info_->tensor_to_step_' if all ensemble outputs are requested
std::unordered_map<std::string, std::set<size_t>>* tensor_to_step_;
Expand Down Expand Up @@ -592,17 +512,6 @@ EnsembleContext::EnsembleContext(
}
}

// Initialize step inflight request limiters for each step.
if (info_->max_inflight_requests_ > 0) {
size_t num_steps = info_->steps_.size();
step_inflight_request_limiters_.reserve(num_steps);
for (size_t i = 0; i < num_steps; i++) {
step_inflight_request_limiters_.emplace_back(
std::make_unique<StepInflightRequestLimiter>(
info_->max_inflight_requests_));
}
}

if (ensemble_status_.IsOk()) {
request_id_ = lrequest->Id();
correlation_id_ = lrequest->CorrelationId();
Expand Down Expand Up @@ -1016,9 +925,9 @@ EnsembleContext::UpdateEnsembleState(
if (completed_step->response_flags_ &
TRITONSERVER_RESPONSE_COMPLETE_FINAL) {
inflight_step_counter_--;
if (!step_inflight_request_limiters_.empty()) {
step_inflight_request_limiters_[completed_step->step_idx_]
->DecrementInflightCount();
if (!info_->step_inflight_request_limiters_.empty()) {
info_->step_inflight_request_limiters_[completed_step->step_idx_]
->Release();
}
}
RETURN_IF_ERROR(ConsumeResponse(completed_step));
Expand Down Expand Up @@ -1510,13 +1419,6 @@ EnsembleContext::ScheduleSteps(
step->ctx_ = context;
size_t this_step_idx = step->step_idx_;

// Apply step inflight request limiters if configured.
if (!context->step_inflight_request_limiters_.empty()) {
context->step_inflight_request_limiters_[this_step_idx]->WaitForCapacity(
context->request_tracker_, this_step_idx,
context->info_->ensemble_name_);
}

bool should_schedule = false;
// Must release lock before InferAsync to avoid deadlock, as the same thread
// will be calling request/response callbacks on cache hits, which will
Expand All @@ -1537,6 +1439,14 @@ EnsembleContext::ScheduleSteps(
if (context->request_tracker_->Request()->IsCancelled()) {
step->request_->Cancel();
}
// Acquire a slot from the per-step shared limiter only for steps that
// will be dispatched, so that a failed ensemble does not hold capacity
// unnecessarily.
if (!context->info_->step_inflight_request_limiters_.empty()) {
context->info_->step_inflight_request_limiters_[this_step_idx]->Acquire(
context->request_tracker_->Request(), this_step_idx,
context->info_->ensemble_name_);
}
// On a successful call to InferAsync(), the step will be released by
// the response callback. When the response callback is invoked, the
// step must not own (and release) the request as the request should be
Expand All @@ -1546,13 +1456,6 @@ EnsembleContext::ScheduleSteps(
std::unique_ptr<InferenceRequest> request = std::move(step->request_);
auto step_status = context->is_->InferAsync(request);
if (step_status.IsOk()) {
// Increment inflight counter AFTER successful scheduling. Always
// increment for ALL steps (including step 0) to ensure symmetry with
// decrement and prevent underflow when steps complete.
if (!context->step_inflight_request_limiters_.empty()) {
context->step_inflight_request_limiters_[this_step_idx]
->IncrementInflightCount();
}
step.release();
continue;
} else {
Expand All @@ -1565,10 +1468,20 @@ EnsembleContext::ScheduleSteps(

// Reaching here means the step is not being scheduled, update corresponding
// counters and attempt to finish ensemble if it is the last step.


// Release the limiter slot if one was acquired, and update counters.
if (should_schedule &&
!context->info_->step_inflight_request_limiters_.empty()) {
context->info_->step_inflight_request_limiters_[this_step_idx]->Release();
}

std::lock_guard<std::mutex> lock(context->mutex_);
// The request is not sent to server properly, shouldn't expect its
// release function get called.
context->request_tracker_->DecrementCounter();
// Decrement only when IncrementCounter was called. An unconditional
// decrement would underflow the counter and cause a use-after-free.
if (should_schedule) {
context->request_tracker_->DecrementCounter();
}
--context->inflight_step_counter_;

if (context->inflight_step_counter_ == 0) {
Expand All @@ -1579,6 +1492,64 @@ EnsembleContext::ScheduleSteps(

} // namespace

StepInflightRequestLimiter::StepInflightRequestLimiter(
const size_t max_inflight)
: inflight_count_(0), max_inflight_(max_inflight)
{
}

void
StepInflightRequestLimiter::Acquire(
const std::unique_ptr<InferenceRequest>& request, const size_t step_idx,
const std::string& ensemble_name)
{
// No limit is configured, so requests are not blocked.
if (max_inflight_ == 0) {
return;
}

std::unique_lock<std::mutex> lk(mutex_);
auto timeout = std::chrono::seconds(kMutexTimeoutSeconds);

auto is_request_cancelled = [&]() {
return (request == nullptr) || request->IsCancelled();
};

bool capacity_available = cv_.wait_for(lk, timeout, [&] {
return is_request_cancelled() || (inflight_count_ < max_inflight_);
});

if (!capacity_available && !is_request_cancelled()) {
Comment thread
pskiran1 marked this conversation as resolved.
LOG_ERROR << "[Internal Error] Ensemble '" << ensemble_name
<< "' unable to schedule step " << step_idx
<< " (inflight: " << inflight_count_
<< " >= limit: " << max_inflight_ << ") for "
<< kMutexTimeoutSeconds
<< " seconds. Proceeding to avoid deadlock.";
}

// Increment while holding the lock to prevent transient oversubscription.
inflight_count_++;
Comment thread
pskiran1 marked this conversation as resolved.
}

// Return one previously acquired slot and wake a single waiting producer.
// Guards against underflow, which would indicate an Acquire/Release
// pairing bug elsewhere.
void
StepInflightRequestLimiter::Release()
{
  // Limiting disabled; nothing was acquired, so nothing to return.
  if (max_inflight_ == 0) {
    return;
  }

  std::lock_guard<std::mutex> guard(mutex_);
  if (inflight_count_ == 0) {
    LOG_ERROR << "[Internal Error] step inflight request limiter underflow";
    return;
  }

  --inflight_count_;
  cv_.notify_one();
}

Status
EnsembleScheduler::Create(
InferenceStatsAggregator* const stats_aggregator,
Expand Down Expand Up @@ -1736,14 +1707,24 @@ EnsembleScheduler::EnsembleScheduler(
}
callback_pool_ = is_->EnsembleCallbackPool();

// Parse the configuration for max_inflight_requests from the protobuf field.
// Parse the max_inflight_requests configuration from the protobuf field
if (config.has_ensemble_scheduling()) {
info_->max_inflight_requests_ =
config.ensemble_scheduling().max_inflight_requests();
if (info_->max_inflight_requests_ > 0) {
LOG_INFO << "Ensemble model '" << config.name()
<< "' configured with max_inflight_requests: "
<< info_->max_inflight_requests_;

// Allocate one limiter per step to ensure max_inflight_requests is
// enforced as a shared limit across all concurrent requests for this
// ensemble model.
info_->step_inflight_request_limiters_.reserve(info_->steps_.size());
for (size_t i = 0; i < info_->steps_.size(); ++i) {
info_->step_inflight_request_limiters_.emplace_back(
std::make_unique<StepInflightRequestLimiter>(
info_->max_inflight_requests_));
}
}
}
}
Expand Down
44 changes: 36 additions & 8 deletions src/ensemble_scheduler/ensemble_scheduler.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright 2019-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -50,6 +50,31 @@ using cudaStream_t = void*;

class InferenceServer;

// Enforces a limit on concurrent in-flight requests for a single ensemble
// step, shared across all active ensemble requests for a given ensemble
// model. Producers that would exceed the limit block in Acquire() until a
// slot is freed by Release(). A limit of 0 disables all tracking.
// All mutable state is guarded by 'mutex_'.
class StepInflightRequestLimiter {
 public:
  explicit StepInflightRequestLimiter(size_t max_inflight);

  // Blocks until a slot is available or the request is cancelled. Cancelled
  // requests skip the wait so cancellation propagates via the normal
  // step-scheduling path. The const reference prevents ownership transfer;
  // only IsCancelled() is queried on the pointed-to request.
  // 'step_idx' and 'ensemble_name' are used for error logging only.
  void Acquire(
      const std::unique_ptr<InferenceRequest>& request, size_t step_idx,
      const std::string& ensemble_name);

  // Releases one acquired slot and wakes one waiting thread. Logs an
  // internal error (instead of underflowing) if no slot is held.
  void Release();

 private:
  // Number of requests currently in flight for this step; guarded by mutex_.
  size_t inflight_count_;
  // Maximum concurrent in-flight requests; 0 means unlimited.
  const size_t max_inflight_;
  std::mutex mutex_;
  std::condition_variable cv_;
};

struct EnsembleInfo {
struct StepInfo {
StepInfo(const ModelIdentifier& model_id, const int64_t model_version)
Expand Down Expand Up @@ -84,14 +109,17 @@ struct EnsembleInfo {
// backward path, ensemble tensor to the step that provides its data
std::unordered_map<std::string, size_t> tensor_to_prev_step_;

// The maximum number of concurrent inflight requests allowed at each ensemble
// step per inference request. This limit is applied per step and per
// inference request, not globally for the entire ensemble model. This limit
// prevents unbounded memory growth when ensemble steps produce responses
// faster than downstream steps can consume them. Default value is 0, which
// indicates that no limit is enforced. Configured via 'max_inflight_requests'
// field in ensemble_scheduling.
// The maximum number of concurrent in-flight requests allowed at each
// ensemble step across all concurrent ensemble requests for this model.
// The limit is applied per step index and is shared across all concurrent
// requests for this ensemble model.
// This limit prevents unbounded memory growth when upstream steps
// produce responses faster than downstream steps can consume them.
// A value of 0 means no limit is enforced.
// Configured via the 'max_inflight_requests' field in ensemble_scheduling.
size_t max_inflight_requests_ = 0;
std::vector<std::unique_ptr<StepInflightRequestLimiter>>
step_inflight_request_limiters_;
Comment thread
pskiran1 marked this conversation as resolved.
};

// Scheduler that implements ensemble scheduling.
Expand Down
Loading