Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 88 additions & 38 deletions deepconf/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ def deepthink(
sampling_params: Optional[SamplingParams] = None,
# Multiple voting options
compute_multiple_voting: bool = True,
# Online mode specific (adaptive sampling)
consensus_threshold: float = 0.95,
adaptive_step_size: int = 1,
**kwargs
) -> DeepThinkOutput:
"""
Expand All @@ -96,6 +99,8 @@ def deepthink(
window_size: Window size for confidence computation
sampling_params: Custom vLLM sampling parameters
compute_multiple_voting: Whether to compute multiple voting method results
consensus_threshold: Confidence threshold for consensus in online mode
adaptive_step_size: Step size for adaptive sampling in online mode

Returns:
DeepThinkOutput containing results
Expand Down Expand Up @@ -129,11 +134,13 @@ def deepthink(
"warmup_traces": warmup_traces,
"total_budget": total_budget,
"confidence_percentile": confidence_percentile,
"consensus_threshold": consensus_threshold,
"adaptive_step_size": adaptive_step_size,
})
result = self._deepthink_online(
prompt, output,
warmup_traces, total_budget, confidence_percentile,
window_size, sampling_params
window_size, sampling_params, consensus_threshold, adaptive_step_size
)
else:
output.config.update({
Expand Down Expand Up @@ -177,7 +184,9 @@ def _deepthink_online(
total_budget: int,
confidence_percentile: int,
window_size: int,
sampling_params: Optional[SamplingParams]
sampling_params: Optional[SamplingParams],
consensus_threshold: float,
adaptive_step_size: int
) -> DeepThinkOutput:
"""Online deep thinking with confidence-based early stopping"""

Expand Down Expand Up @@ -212,51 +221,92 @@ def _deepthink_online(
output.warmup_tokens = warmup_result['total_tokens']

print(f"Warmup completed: conf_bar={output.conf_bar:.3f}")

# Final phase
print(f"Starting final phase...", sampling_params)
print(f"Starting adaptive final phase (step={adaptive_step_size}, threshold={consensus_threshold})...")
final_gen_start = time.time()
# final_params = copy.deepcopy(sampling_params)
# final_params.seed = int(time.time())
# final_params.n = total_budget - warmup_traces

final_params_list = []
for param_id in range(total_budget - warmup_traces):
final_params = copy.deepcopy(sampling_params)
final_params.logprobs = 20
final_params.seed = base_seed + param_id + warmup_traces
final_params.extra_args = {
"conf_threshold": output.conf_bar,
"eos_token_id": self.tokenizer.eos_token_id,
"conf_group_size": window_size,
"conf_topk": 20,
}
final_params_list.append(final_params)
final_outputs = self.llm.generate([prompt for _ in range(total_budget - warmup_traces)], final_params_list)
output.final_gen_time = time.time() - final_gen_start

# Process final results
final_process_start = time.time()
final_result = process_batch_results(final_outputs, window_size)
output.final_process_time = time.time() - final_process_start

print('Final min_confs:', final_result['min_confs'])
output.final_min_confs = final_result['min_confs']
# Initialize containers
output.final_traces = []
output.final_tokens = 0
output.final_min_confs = []

output.final_traces = final_result['traces']
output.final_tokens = final_result['total_tokens']
# Pool for consensus checking (starts with warmup traces)
current_traces = output.warmup_traces[:]
remaining_budget = total_budget - warmup_traces
base_seed = time.time_ns()

# Apply confidence threshold to final traces
for trace in output.final_traces:
if trace["min_conf"] < output.conf_bar:
trace["stop_reason"] = "gconf_threshold"
# Adaptive Sampling Loop (Algorithm 2)
while remaining_budget > 0:
# Consensus Check (Before generating new traces)
valid_answers = []
weights = []

for trace in current_traces:
# Filter: Must have answer AND not be stopped by confidence filter
if trace.get('extracted_answer') and trace.get('stop_reason') != 'gconf_threshold':
valid_answers.append(trace['extracted_answer'])
weights.append(trace.get('min_conf', 0))

if valid_answers:
# Weighted Majority Vote
winner = weighted_majority_vote(valid_answers, weights)

if winner:
total_weight = sum(weights)
winner_weight = sum(w for a, w in zip(valid_answers, weights) if a == winner)
beta = winner_weight / total_weight if total_weight > 0 else 0.0

# Stop if consensus reached
if beta >= consensus_threshold:
print(f"Consensus reached (beta={beta:.3f}). Stopping early.")
break

# Prepare Next Step
# Use step_size=1 to strictly follow paper, or higher for speed
step = min(adaptive_step_size, remaining_budget)
current_params_list = []

for i in range(step):
# Ensure unique seeds
current_seed = base_seed + (total_budget - remaining_budget) + i
p = copy.deepcopy(sampling_params)
p.logprobs = 20
p.seed = current_seed
p.extra_args = {
"conf_threshold": output.conf_bar,
"eos_token_id": self.tokenizer.eos_token_id,
"conf_group_size": window_size,
"conf_topk": 20,
}
current_params_list.append(p)

# Generate
outputs = self.llm.generate([prompt] * step, current_params_list)

# Process
batch_proc_start = time.time()
batch_res = process_batch_results(outputs, window_size)
output.final_process_time += (time.time() - batch_proc_start)

output.final_traces.extend(batch_res['traces'])
output.final_min_confs.extend(batch_res['min_confs'])
output.final_tokens += batch_res['total_tokens']

# Update trace pool (flagging early stops)
for trace in batch_res['traces']:
if trace["min_conf"] < output.conf_bar:
trace["stop_reason"] = "gconf_threshold"
current_traces.append(trace)

remaining_budget -= step

output.final_gen_time = time.time() - final_gen_start

# Combine all traces
# Combine results
output.all_traces = output.warmup_traces + output.final_traces
output.total_tokens = output.warmup_tokens + output.final_tokens
output.total_traces_count = len(output.all_traces)

# Basic voting (for backward compatibility)
# Basic voting (Backward compatibility)
self._perform_basic_voting(output)

output.processing_time = time.time() - processing_start
Expand Down