This document consolidates everything you need to use DeepConf on vLLM:
- How to modify vLLM to implement DeepConf - A minimal patch guide showing what to change, where, and why.
- Or skip the modifications - Directly download our PR with pre-implemented changes.
- Example code to run DeepConf - Complete working examples demonstrating the feature.
Option 1: Directly download our PR
- Related PR: https://github.com/vllm-project/vllm/pull/23201
Option 2: Modify vLLM
This README documents a minimal set of edits to integrate DeepConf (confidence-based early stopping) into vLLM, how to enable it via the OpenAI-compatible API, and pointers to the example notebook.
Tested Environment
- vLLM: commit 31f09c615f4f067dba765ce5fe7d00d880212a6d
- Python: 3.12.0
- CUDA: 12.8
High-Level Changes
We modify only two places in vLLM:
- Extend LogprobsProcessor to maintain a sliding-window confidence and expose check_conf_stop().
- In output_processor.py, insert a single early-stop check before constructing RequestOutput.
Enabling via OpenAI-Compatible API
The feature is toggled per request via the OpenAI-compatible chat.completions endpoint. The arguments are passed through extra_body["vllm_xargs"] and forwarded by vLLM to SamplingParams.extra_args.
# Code: Enable confidence-based early stopping via OpenAI-compatible API
responses = client.chat.completions.create(
model=args.model_path,
messages=messages,
max_tokens=args.max_tokens,
temperature=0.6,
top_p=0.95,
logprobs=True,
top_logprobs=20, # request candidate logprobs (>=2)
n=real_gen,
extra_body={
"top_k": 0,
"vllm_xargs": {
"enable_conf": True,
"window_size": 2048,
"threshold": conf_threshold
}
}
)
- The early-stop logic is inactive unless logprobs=True and top_logprobs >= 2.
- window_size is the confidence window length; threshold is the cutoff used by our method.
- top_k=0 (optional) disables top-k truncation.
Exact Edits (Copy-Paste Guidance)
No patch tools are required; copy the snippets below into the indicated files. We recommend pinning to the commit above to avoid API drift.
File: vllm/v1/engine/logprobs.py
Step 1: Import (near the top):
from collections import deque
from typing import Optional, List
Step 2: Extend the dataclass (class LogprobsProcessor):
# --- fields for confidence-based early stopping ---
conf_grouped: float
conf_list: Optional[List[float]]
conf_group_list: Optional[deque]
conf_group_size: int
conf_threshold: Optional[float]
Step 3: Initialize from the request (inside from_new_request(...), right before return cls(...)):
if hasattr(request.sampling_params, "extra_args") \
        and request.sampling_params.extra_args is not None \
        and request.sampling_params.extra_args.get("enable_conf", False):
    conf_group_size = request.sampling_params.extra_args.get("window_size", 2048)
    conf_threshold = request.sampling_params.extra_args.get("threshold", 17)
    conf_grouped = 0.0
    conf_group_list = deque(maxlen=conf_group_size)
    conf_list = []
else:
    conf_group_size = -1
    conf_threshold = None
    conf_grouped = 0.0
    conf_group_list = None
    conf_list = None
Then include the fields below in the return cls(...) call:
conf_group_size=conf_group_size,
conf_grouped=conf_grouped,
conf_list=conf_list,
conf_threshold=conf_threshold,
conf_group_list=conf_group_list,
Step 4: Stop-check helper (add this method inside the class):
def check_conf_stop(self) -> bool:
    """Return True if the confidence window triggers early stopping."""
    if self.conf_group_list is None or len(self.conf_group_list) == 0:
        return False
    # Require a full window; trigger when the moving average is below threshold.
    return (len(self.conf_group_list) >= self.conf_group_size
            and self.conf_grouped / len(self.conf_group_list) < self.conf_threshold)
Step 5: Update confidence during sampling (at the end of _update_sample_logprobs(...), after appending the logprob dict):
if self.conf_list is not None:
    # logprobs[0] is the sampled token; use the remaining candidates
    if len(logprobs) > 1:
        new_conf = -sum(logprobs[1:]) / len(logprobs[1:])
    else:
        new_conf = 0.0
    self.conf_list.append(new_conf)
    if len(self.conf_group_list) < self.conf_group_size:
        self.conf_group_list.append(new_conf)
        self.conf_grouped += new_conf
    else:
        self.conf_grouped -= self.conf_group_list.popleft()
        self.conf_group_list.append(new_conf)
        self.conf_grouped += new_conf
File: vllm/v1/engine/output_processor.py
Step 6: Invoke the stop-check in the decode loop.
Immediately after:
req_state.logprobs_processor.update_from_output(engine_core_output)
insert:
# Confidence-based early stopping (ours)
if req_state.logprobs_processor.check_conf_stop():
    finish_reason = FinishReason.STOP
    stop_reason = f"<gconf<{req_state.logprobs_processor.conf_threshold}>"
(Leave the subsequent logic that builds RequestOutput unchanged.)
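On the client side, early-stopped completions can be identified from the annotated stop_reason that vLLM returns alongside finish_reason. A minimal sketch, mirroring the filtering used in Example 3 below (the helper name is ours):
def was_early_stopped(choice) -> bool:
    """Return True if this completion was cut short by confidence-based early stopping."""
    # Step 6 sets stop_reason to a string containing "gconf" when the check triggers.
    stop_reason = getattr(choice, "stop_reason", None)
    return stop_reason is not None and "gconf" in str(stop_reason)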
Additional Notes
- The feature is inactive unless enable_conf=True and logprobs > 0 (we use top_logprobs=20).
- Confidence is the moving average of the negative mean candidate logprobs over a fixed window (window_size); a standalone sketch is given below.
- When triggered, we set FinishReason.STOP and annotate stop_reason with <gconf<THR>> for traceability.
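For concreteness, here is a minimal standalone sketch of the window statistic computed over per-token candidate logprobs (illustrative only; in the engine, LogprobsProcessor maintains the running sum incrementally as in Step 5, and the function name here is ours):
from collections import deque

def windowed_confidence(candidate_logprobs, window_size=2048):
    """Per-token confidence = negative mean of the candidate logprobs;
    the window statistic is its moving average over the last window_size tokens."""
    window = deque(maxlen=window_size)
    running_sum = 0.0
    averages = []
    for cands in candidate_logprobs:  # one list of candidate logprobs per generated token
        conf = -sum(cands) / len(cands) if cands else 0.0
        if len(window) == window_size:
            running_sum -= window[0]  # the oldest confidence is about to be evicted
        window.append(conf)
        running_sum += conf
        averages.append(running_sum / len(window))
    return averages

# A request stops early once the window is full and the latest average drops below threshold,
# mirroring check_conf_stop() above.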
Alternatively, skip the manual edits: go to the PR linked above and download the vLLM version with these changes already applied.
Running Example
Install vLLM
Download vLLM and apply the modifications above, or download the version from the PR https://github.com/vllm-project/vllm/pull/23201, then build it as follows.
VLLM_USE_PRECOMPILED=1 uv pip install --editable .
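After building, a quick sanity check (a minimal sketch) confirms that the editable install is the one being imported:
import vllm

# Print the installed version and package location to confirm the editable build is active.
print(vllm.__version__, vllm.__file__)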
Install dependencies for the example code
git clone https://github.com/hao-ai-lab/Dynasor.git
cd Dynasor && pip install . && cd -
Example 1: Offline Generation
python
import openai
import json
from tqdm import tqdm
import time
import os
import requests
from datetime import datetime
from transformers import AutoTokenizer
import concurrent.futures
import threading
from functools import partial
# ===========================
# Model Configurations
# ===========================
MODEL_CONFIGS = {
"Qwen/Qwen3-8B": {
"temperature": 0.6,
"top_p": 0.95,
"top_k": 20,
"max_tokens": 32000,
"template": "qwen3"
},
"deepseek-ai/DeepSeek-R1-0528-Qwen3-8B": {
"temperature": 0.6,
"top_p": 0.95,
"top_k": 0,
"max_tokens": 64000,
"template": "dpsk_qwen_0528"
},
"openai/gpt-oss-20b": {
"temperature": 1.0,
"top_p": 1.0,
"top_k": 40,
"max_tokens": 14000,
"template": "gpt"
},
# Add more model configurations as needed
}
# ===========================
# Main Configuration
# ===========================
# Select your model
MODEL_NAME = "Qwen/Qwen3-8B" # Change this to your desired model
SAMPLES_PER_QUESTION = 4 # Number of traces to generate per question
DATASET_FILE = "aime25.jsonl" # Input dataset file
REASONING_EFFORT = "high" # For GPT models: low, medium, high
# Parallel processing configuration
MAX_WORKERS = 8 # Maximum number of concurrent workers (adjust based on your server capacity)
MAX_WORKERS_PER_QUESTION = 4 # Maximum workers for traces within a single question
# Get model-specific config
model_config = MODEL_CONFIGS.get(MODEL_NAME)
# General Configuration
CONFIG = {
"model_path": MODEL_NAME,
"server_port": 8000,
"temperature": model_config["temperature"],
"top_p": model_config["top_p"],
"top_k": model_config["top_k"],
"max_tokens": model_config["max_tokens"],
"template": model_config["template"],
"reasoning_effort": REASONING_EFFORT,
# Dataset and sampling configuration
"dataset": DATASET_FILE, # Input dataset file
"max_samples_per_question": SAMPLES_PER_QUESTION, # Number of traces per question
"output_dir": f"output_{MODEL_NAME.replace('/', '_')}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
# Parallel processing configuration
"max_workers": MAX_WORKERS,
"max_workers_per_question": MAX_WORKERS_PER_QUESTION,
}
# Thread-safe file writing lock
file_lock = threading.Lock()
# ===========================
# Initialize OpenAI Client
# ===========================
# Note: Make sure vLLM server is already running on the specified port
# Example command to start vLLM server:
# vllm serve MODEL_NAME --port 8000 -tp 1 --gpu-memory-utilization 0.7 --enable-prefix-caching
print(f"Connecting to vLLM server...")
print(f"Model: {CONFIG['model_path']}")
print(f"Server URL: http://localhost:{CONFIG['server_port']}/v1")
print(f"Max concurrent workers: {CONFIG['max_workers']}")
print(f"Max workers per question: {CONFIG['max_workers_per_question']}")
# Initialize OpenAI client
client = openai.OpenAI(
api_key="None",
base_url=f"http://localhost:{CONFIG['server_port']}/v1",
timeout=None
)
# Initialize tokenizer for GPT models
if CONFIG['template'] == 'gpt':
tokenizer = AutoTokenizer.from_pretrained(CONFIG['model_path'])
else:
tokenizer = None
# Test connection
try:
response = requests.get(
f"http://localhost:{CONFIG['server_port']}/v1/models",
headers={"Authorization": "Bearer None"},
)
if response.status_code == 200:
print("✅ Successfully connected to vLLM server")
else:
print(f"⚠️ Server returned status code: {response.status_code}")
except requests.exceptions.RequestException as e:
print(f"❌ Failed to connect to vLLM server: {e}")
print("Please ensure vLLM server is running on the specified port")
# ===========================
# Core Processing Functions
# ===========================
def get_gpt_token_probabilities(messages, max_tokens=50):
"""
Function to get token probabilities for GPT models using completions API
"""
response = client.completions.create(
model=CONFIG['model_path'],
prompt=messages,
max_tokens=max_tokens,
temperature=CONFIG['temperature'],
top_p=CONFIG['top_p'],
logprobs=20,
extra_body={
"top_k": CONFIG['top_k']
},
)
# Extract generated text
generated_text = response.choices[0].text
# Extract token probabilities
token_probs = []
mean_confs = []
tokens = []
log_probs = []
if response.choices[0].logprobs and response.choices[0].logprobs.tokens:
for i, token_data in enumerate(response.choices[0].logprobs.tokens):
step_probs = {
"s": i, # step -> s
"t": response.choices[0].logprobs.tokens[i], # generated_token -> t
"lp": round(response.choices[0].logprobs.token_logprobs[i], 2), # logprob of generated token
"a": [] # top_20_tokens -> a (alternatives)
}
# Add only top 5 alternatives to save space
if response.choices[0].logprobs.top_logprobs:
for tok, value in list(response.choices[0].logprobs.top_logprobs[i].items())[:5]:  # keep only the top 5
step_probs["a"].append([
tok,
round(value, 2)
]) # Use array instead of dict
token_probs.append(step_probs)
if step_probs['a']:
mean_confs.append(round(-sum(p[1] for p in step_probs['a']) / len(step_probs['a']), 2))
else:
mean_confs.append(0)
tokens.append(response.choices[0].logprobs.tokens[i])
log_probs.append(round(response.choices[0].logprobs.token_logprobs[i], 2))
return {
"text": generated_text,
"probs": token_probs, # token_probabilities -> probs,
"mean_confidences": mean_confs, # mean_confidences -> mean_confs
"tokens": tokens,
"log_probs": log_probs # log_probs -> log_probs
}
def get_token_probabilities(prompt, messages):
"""Get token probabilities from the vLLM server using chat completions API."""
response = client.chat.completions.create(
model=CONFIG['model_path'],
messages=messages,
max_tokens=CONFIG['max_tokens'],
temperature=CONFIG['temperature'],
top_p=CONFIG['top_p'],
logprobs=True,
top_logprobs=20,
extra_body={"top_k": CONFIG['top_k']},
)
generated_text = response.choices[0].message.content
token_probs = []
mean_confs = []
tokens = []
log_probs = []
if response.choices[0].logprobs and response.choices[0].logprobs.content:
for i, token_data in enumerate(response.choices[0].logprobs.content):
step_probs = {
"s": i,
"t": token_data.token,
"lp": round(token_data.logprob, 2),
"a": []
}
if token_data.top_logprobs:
for logprob_data in token_data.top_logprobs[:5]: # Top 5 alternatives
step_probs["a"].append([
logprob_data.token,
round(logprob_data.logprob, 2)
])
token_probs.append(step_probs)
if step_probs['a']:
mean_confs.append(round(-sum(p[1] for p in step_probs['a']) / len(step_probs['a']), 2))
else:
mean_confs.append(0)
tokens.append(token_data.token)
log_probs.append(round(token_data.logprob, 2))
return {
"text": generated_text,
"probs": token_probs,
"mean_confidences": mean_confs,
"tokens": tokens,
"log_probs": log_probs
}
def prepare_messages(prompt, template):
"""Prepare messages based on template."""
if template == "dpsk_qwen_0528":
return [
{"role": "system", "content": "该助手为DeepSeek-R1,由深度求索公司创造。\n今天是2025年5月28日,星期一。\n"},
{"role": "user", "content": prompt}
]
elif template == 'qwen3':
return [
{"role": "user", "content": prompt + "\nPlease reason step by step, and put your final answer within \\boxed{}."}
]
elif template == 'gpt':
# For GPT models, we'll prepare a simple string message first
return prompt + "\nPlease reason step by step, and put your final answer within \\boxed{}."
else:
return [{"role": "user", "content": prompt}]
def generate_single_trace(question_meta, trace_idx, output_dir):
"""Generate a single trace for a question. This function will be run in parallel."""
try:
prompt = question_meta["prompt"]
q_idx = question_meta["question_id"]
messages = prepare_messages(prompt, CONFIG['template'])
# Handle GPT models differently
if CONFIG['template'] == 'gpt':
# Apply chat template with reasoning effort for GPT models
if tokenizer:
formatted_messages = tokenizer.apply_chat_template(
conversation=[
{"role": "user", "content": messages}
],
add_generation_prompt=True,
reasoning_effort=CONFIG['reasoning_effort'],
tokenize=False,
)
result = get_gpt_token_probabilities(messages=formatted_messages, max_tokens=CONFIG['max_tokens'])
else:
# Fallback if tokenizer is not available
result = get_gpt_token_probabilities(messages=messages, max_tokens=CONFIG['max_tokens'])
else:
# Use chat completions for other models
result = get_token_probabilities(prompt, messages)
# Prepare trace data
trace_data_processed = {
"question_meta": question_meta,
"trace_id": trace_idx,
"response": result["text"],
"tokens": result["tokens"],
"mean_confidences": result["mean_confidences"],
"log_probs": result["log_probs"],
"messages": messages,
}
# Thread-safe file writing
processed_file = os.path.join(output_dir, f"{q_idx}_processed.jsonl")
with file_lock:
with open(processed_file, "a", encoding="utf-8") as f:
f.write(json.dumps(trace_data_processed, ensure_ascii=False) + "\n")
return True, None
except Exception as e:
return False, f"Error in question {question_meta['question_id']}, trace {trace_idx}: {e}"
def process_question_parallel(question, q_idx, output_dir):
"""Process a single question and generate multiple traces in parallel."""
prompt = question.get("problem", question.get("question", question.get("prompt", "")))
if not prompt:
print(f"Warning: No prompt found in question {q_idx}")
return 0
question_meta = {
"question_id": q_idx,
"original_question": question,
"prompt": prompt,
}
# Generate traces in parallel
completed_traces = 0
with concurrent.futures.ThreadPoolExecutor(max_workers=CONFIG['max_workers_per_question']) as executor:
# Submit all trace generation tasks
future_to_trace = {
executor.submit(generate_single_trace, question_meta, trace_idx, output_dir): trace_idx
for trace_idx in range(CONFIG['max_samples_per_question'])
}
# Collect results as they complete
for future in concurrent.futures.as_completed(future_to_trace):
trace_idx = future_to_trace[future]
try:
success, error_msg = future.result()
if success:
completed_traces += 1
else:
print(f"Error: {error_msg}")
except Exception as e:
print(f"Exception in trace {trace_idx}: {e}")
return completed_traces
def process_single_question_wrapper(args):
"""Wrapper function for processing a single question (needed for parallel execution)."""
question, q_idx, output_dir = args
return q_idx, process_question_parallel(question, q_idx, output_dir)
def process_dataset_parallel(dataset_file, output_dir):
"""Process entire dataset with parallel processing."""
os.makedirs(output_dir, exist_ok=True)
print(f"Created output directory: {output_dir}")
# Load dataset
questions = []
try:
with open(dataset_file, "r", encoding="utf-8") as f:
for line in f:
questions.append(json.loads(line.strip()))
print(f"Loaded {len(questions)} questions from {dataset_file}")
except FileNotFoundError:
print(f"Error: {dataset_file} not found!")
return None
# Process questions in parallel
all_results = []
total_traces = 0
# Prepare arguments for parallel processing
question_args = [(question, q_idx, output_dir) for q_idx, question in enumerate(questions)]
print(f"Processing {len(questions)} questions with up to {CONFIG['max_workers']} parallel workers...")
with concurrent.futures.ThreadPoolExecutor(max_workers=CONFIG['max_workers']) as executor:
# Submit all question processing tasks
future_to_question = {
executor.submit(process_single_question_wrapper, args): args[1]
for args in question_args
}
# Use tqdm to track progress
with tqdm(total=len(questions), desc="Processing questions") as pbar:
for future in concurrent.futures.as_completed(future_to_question):
q_idx = future_to_question[future]
try:
result_q_idx, traces_completed = future.result()
total_traces += traces_completed
all_results.append({
"question_id": result_q_idx,
"total_traces": traces_completed,
"file_path": os.path.join(output_dir, f"{result_q_idx}.jsonl")
})
pbar.update(1)
pbar.set_postfix({
'completed_traces': total_traces,
'avg_traces': f"{total_traces / len(all_results):.1f}" if all_results else "0"
})
except Exception as e:
print(f"Exception processing question {q_idx}: {e}")
pbar.update(1)
# Save summary
summary_file = os.path.join(output_dir, "summary.json")
summary = {
"model": CONFIG['model_path'],
"model_config": model_config,
"dataset_file": dataset_file,
"total_questions": len(questions),
"completed_questions": len(all_results),
"total_traces": total_traces,
"average_traces_per_question": total_traces / len(all_results) if all_results else 0,
"output_directory": output_dir,
"timestamp": datetime.now().isoformat(),
"reasoning_effort": CONFIG.get('reasoning_effort', 'N/A'),
"parallel_config": {
"max_workers": CONFIG['max_workers'],
"max_workers_per_question": CONFIG['max_workers_per_question']
}
}
with open(summary_file, "w", encoding="utf-8") as f:
json.dump(summary, f, indent=2, ensure_ascii=False)
print(f"\n✅ Completed! Generated {total_traces} total traces")
print(f"📁 Results saved to: {output_dir}")
print(f"📊 Summary: {summary_file}")
print(f"📈 Average traces per question: {total_traces / len(all_results):.1f}")
return output_dir
def check_results(output_dir):
"""Check and display results from output directory."""
if not os.path.exists(output_dir):
print(f"Directory {output_dir} not found!")
return
# Load summary
summary_file = os.path.join(output_dir, "summary.json")
if os.path.exists(summary_file):
with open(summary_file, "r", encoding="utf-8") as f:
summary = json.load(f)
print(f"\nSummary:")
print(f" Model: {summary['model']}")
print(f" Total questions: {summary['total_questions']}")
print(f" Total traces: {summary['total_traces']}")
print(f" Average traces per question: {summary['average_traces_per_question']:.1f}")
print(f" Reasoning effort: {summary.get('reasoning_effort', 'N/A')}")
if 'parallel_config' in summary:
print(f" Max workers: {summary['parallel_config']['max_workers']}")
print(f" Max workers per question: {summary['parallel_config']['max_workers_per_question']}")
# Check individual files
question_files = [f for f in os.listdir(output_dir) if f.endswith('.jsonl')]
question_files.sort(key=lambda x: int(x.split('.')[0].split('_')[0]) if x.split('.')[0].split('_')[0].isdigit() else float('inf'))
print(f"\nFound {len(question_files)} question files")
# Show sample results from the first few processed trace files
for filename in question_files[:3]:
filepath = os.path.join(output_dir, filename)
if os.path.exists(filepath):
with open(filepath, "r", encoding="utf-8") as f:
lines = f.readlines()
if lines:
first_trace = json.loads(lines[0].strip())
print(f"\n{filename}:")
print(f" Traces: {len(lines)}")
print(f" First response preview: {first_trace['response'][:150]}...")
# ===========================
# Performance Monitoring
# ===========================
def monitor_performance():
"""Monitor and suggest optimal worker configuration."""
import psutil
cpu_count = psutil.cpu_count()
memory_gb = psutil.virtual_memory().total / (1024**3)
print(f"\n🔧 System Information:")
print(f" CPU cores: {cpu_count}")
print(f" Total memory: {memory_gb:.1f} GB")
# Suggest optimal configuration
suggested_workers = min(cpu_count * 2, 16) # Generally good for I/O bound tasks
suggested_workers_per_q = min(4, suggested_workers // 2)
print(f"\n💡 Suggested Configuration:")
print(f" MAX_WORKERS: {suggested_workers}")
print(f" MAX_WORKERS_PER_QUESTION: {suggested_workers_per_q}")
if CONFIG['max_workers'] > suggested_workers:
print(f"⚠️ Current MAX_WORKERS ({CONFIG['max_workers']}) might be too high for your system")
return suggested_workers, suggested_workers_per_q
# ===========================
# Main Execution
# ===========================
# Monitor system performance
monitor_performance()
print(f"\n🚀 Starting parallel processing:")
print(f" Model: {CONFIG['model_path']}")
print(f" Template: {CONFIG['template']}")
print(f" Dataset: {CONFIG['dataset']}")
print(f" Traces per question: {CONFIG['max_samples_per_question']}")
print(f" Max workers: {CONFIG['max_workers']}")
print(f" Max workers per question: {CONFIG['max_workers_per_question']}")
print(f" Output directory: {CONFIG['output_dir']}")
if CONFIG['template'] == 'gpt':
print(f" Reasoning effort: {CONFIG['reasoning_effort']}")
start_time = time.time()
# Process the dataset with parallel processing
output_dir = process_dataset_parallel(
dataset_file=CONFIG['dataset'],
output_dir=CONFIG['output_dir']
)
end_time = time.time()
processing_time = end_time - start_time
# Check results
if output_dir:
print("\n" + "="*50)
print("Checking generated results...")
check_results(output_dir)
print(f"\n⚡ Processing completed in {processing_time:.2f} seconds!")
print(f"🚀 Speed improvement with parallel processing!")
print("Note: vLLM server is still running. Stop it manually if needed.")
Example 2: Offline Voting
python
import os
import json
import pickle
import numpy as np
import pandas as pd
from collections import Counter, defaultdict
from tqdm import tqdm
import glob
import random
from typing import Dict, List, Tuple, Optional, Any
from dynasor.core.evaluator import math_equal
TRACE_DIRS = ["YOURDIR1", "YOURDIR2", ...]  # Replace with your trace output directories
# ============================================
# SECTION 1: Math Evaluation Functions
# ============================================
# NOTE: assumed imports for the parsers used below (sympy and the latex2sympy2 package); adjust to your setup
from sympy.parsing.latex import parse_latex
from sympy.parsing.sympy_parser import parse_expr
from latex2sympy2 import latex2sympy

def parse_func(s):
    """Try several parsers in turn; fall back to the raw string if none succeeds."""
    for f in [parse_latex, parse_expr, latex2sympy]:
        try:
            return f(s.replace("\\\\", "\\"))
        except Exception:
            try:
                return f(s)
            except Exception:
                pass
    return s
def quick_parse(text):
"""Quick parse to remove LaTeX text formatting"""
if '\\text{' in text and '}' in text:
while '\\text{' in text:
start = text.find('\\text{')
if start == -1:
break
end = text.find('}', start)
if end == -1:
break
content = text[start + 6:end]
text = text[:start] + content + text[end + 1:]
return text
# ============================================
# SECTION 2: Answer Extraction
# ============================================
def extract_answer(text):
"""Extract boxed answer from text"""
if "boxed" in text:
ans = text.split("boxed")[-1]
if len(ans) == 0:
return ""
elif ans[0] == "{":
stack = 1
a = ""
for c in ans[1:]:
if c == "{":
stack += 1
a += c
elif c == "}":
stack -= 1
if stack == 0:
break
a += c
else:
a += c
else:
a = ans.split("$")[0].strip()
return a.strip()
return None
# ============================================
# SECTION 3: Confidence Metrics Calculation
# ============================================
def calculate_confidence_stats(conf_list, tokens):
"""Calculate various confidence statistics"""
if not conf_list:
return {}
assert len(conf_list) == len(tokens), "Confidence list and tokens must have same length"
conf_array = np.array(conf_list)
total_tokens = len(conf_array)
stats = {
'mean_confidence': np.mean(conf_array)
}
# First/Last N tokens
for n in [2048]:
if total_tokens >= n:
stats[f'tail_{n}_mean_conf'] = np.mean(conf_array[-n:])
else:
stats[f'tail_{n}_mean_conf'] = np.mean(conf_array)
# First/Last percentage
for ratio in [0.1]:
n_tokens = max(1, int(total_tokens * ratio))
stats[f'tail_{ratio}_mean_conf'] = np.mean(conf_array[-n_tokens:])
# Sliding window metrics
window_sizes = [2048]
bottom_percentages = [0.1, 0.5]
for window_size in window_sizes:
if total_tokens < window_size:
stats[f'min_sliding_{window_size}_mean_conf'] = np.mean(conf_array)
for percent in bottom_percentages:
stats[f'bottom_{percent}_sliding_{window_size}_mean_conf'] = np.mean(conf_array)
else:
# Optimized sliding window
cumsum = np.cumsum(conf_array)
window_sums = cumsum[window_size-1:]
window_sums[1:] -= cumsum[:-window_size]
window_means = window_sums / window_size
stats[f'min_sliding_{window_size}_mean_conf'] = np.min(window_means)
sorted_means = np.sort(window_means)
for percent in bottom_percentages:
idx = int(len(sorted_means) * percent)
stats[f'bottom_{percent}_sliding_{window_size}_mean_conf'] = sorted_means[:idx].mean()
return stats
# ============================================
# SECTION 4: Voting Strategies
# ============================================
def majority_vote(traces):
"""Perform majority voting based on extracted answers"""
if not traces:
return None, None
answer_counts = {}
answer_to_parsed = {}
for trace in traces:
extracted_answer = trace.get('extracted_answer')
parsed_answer = trace.get('parsed_answer')
if extracted_answer is not None:
answer_str = str(extracted_answer)
answer_counts[answer_str] = answer_counts.get(answer_str, 0) + 1
if answer_str not in answer_to_parsed:
answer_to_parsed[answer_str] = parsed_answer
if not answer_counts:
return None, None
voted_answer = max(answer_counts.keys(), key=lambda x: answer_counts[x])
voted_parsed = answer_to_parsed[voted_answer]
return voted_answer, voted_parsed
def weighted_majority_vote(traces, weight_key='mean_confidence'):
"""Perform weighted majority voting"""
if not traces:
return None, None
answer_weights = {}
answer_to_parsed = {}
for trace in traces:
extracted_answer = trace.get('extracted_answer')
parsed_answer = trace.get('parsed_answer')
weight = trace.get(weight_key)
if extracted_answer is not None and weight is not None:
answer_str = str(extracted_answer)
answer_weights[answer_str] = answer_weights.get(answer_str, 0.0) + float(weight)
if answer_str not in answer_to_parsed:
answer_to_parsed[answer_str] = parsed_answer
if not answer_weights:
return None, None
voted_answer = max(answer_weights.keys(), key=lambda x: answer_weights[x])
voted_parsed = answer_to_parsed[voted_answer]
return voted_answer, voted_parsed
def top_percent_vote(traces, weight_key='mean_confidence', top_percent=0.1, vote_strategy='majority'):
"""
First filter top percent of traces by weight_key, then perform voting
Args:
traces: List of trace dictionaries
weight_key: Key to use for filtering (e.g., 'mean_confidence')
top_percent: Percentage of top traces to keep (e.g., 0.1 for top 10%)
vote_strategy: 'majority' or 'weighted'
Returns:
voted_answer, voted_parsed
"""
if not traces:
return None, None
# Filter traces that have the weight_key and valid answers
valid_traces = [t for t in traces if weight_key in t and t.get('extracted_answer') is not None]
if not valid_traces:
return None, None
# Sort traces by weight_key in descending order (higher is better)
sorted_traces = sorted(valid_traces, key=lambda x: x[weight_key], reverse=True)
# Select top percent
n_top = max(1, int(len(sorted_traces) * top_percent))
top_traces = sorted_traces[:n_top]
# Apply voting strategy on filtered traces
if vote_strategy == 'majority':
return majority_vote(top_traces)
elif vote_strategy == 'weighted':
return weighted_majority_vote(top_traces, weight_key)
else:
raise ValueError(f"Unknown vote_strategy: {vote_strategy}")
# ============================================
# SECTION 5: JSONL Processing
# ============================================
def process_jsonl_file(file_path, ground_truth=None):
"""Process a single JSONL file and extract traces with metrics"""
traces = []
with open(file_path, 'r') as f:
lines = f.readlines()
for line_num, line in enumerate(lines):
if not line.strip():
continue
try:
data = json.loads(line)
# Extract response and confidence data
response = data.get('response', '')
mean_confidences = data.get('mean_confidences', [])
tokens = data.get('tokens', [])
question_meta = data['question_meta']['original_question']
# Extract answer
extracted_answer = extract_answer(response)
parsed_answer = parse_func(extracted_answer) if extracted_answer else None
# Get ground truth
if ground_truth is None:
# Try to extract from question_meta
for field in ['answer', 'solution', 'target']:
if field in question_meta:
ground_truth = str(question_meta[field]).strip()
break
# Calculate confidence statistics
conf_stats = calculate_confidence_stats(mean_confidences, tokens)
# Check correctness
is_correct = False
if extracted_answer is not None and ground_truth is not None:
is_correct = math_equal(extracted_answer, ground_truth)
# Create trace entry
trace = {
'trace_id': data.get('trace_id', line_num),
'extracted_answer': extracted_answer,
'parsed_answer': parsed_answer,
'is_correct': is_correct,
'ground_truth': ground_truth,
'response': response,
**conf_stats
}
traces.append(trace)
except Exception as e:
print(f"Error processing line {line_num} in {file_path}: {e}")
continue
return traces
def process_multiple_jsonls(file_pattern, ground_truth_map=None):
"""Process multiple JSONL files matching a pattern"""
files = glob.glob(file_pattern)
all_data = defaultdict(list)
for file_path in tqdm(files, desc="Processing JSONL files"):
# Extract question ID from filename if possible
filename = os.path.basename(file_path)
question_id = None
# Try to extract question ID (adjust pattern as needed)
if '_processed.jsonl' in filename:
try:
question_id = int(filename.replace('_processed.jsonl', ''))
except:
question_id = filename
else:
question_id = filename
# Get ground truth for this question
ground_truth = None
if ground_truth_map and question_id in ground_truth_map:
ground_truth = ground_truth_map[question_id]
# Process the file
traces = process_jsonl_file(file_path, ground_truth)
if traces:
all_data[question_id] = traces
return dict(all_data)
def process_multiple_dirs_jsonls(trace_dirs, file_pattern="*_processed.jsonl", ground_truth_map=None):
"""
Process JSONL files from multiple directories and merge traces with same filename
Args:
trace_dirs: List of directory paths to search for JSONL files
file_pattern: File pattern to match (e.g., "*_processed.jsonl")
ground_truth_map: Optional dictionary mapping question IDs to ground truth answers
Returns:
Dictionary where keys are question IDs and values are lists of merged traces
"""
all_data = defaultdict(list)
# First, collect all unique filenames across all directories
all_filenames = set()
dir_file_mapping = defaultdict(list) # Track which dirs have which files
for trace_dir in trace_dirs:
if not os.path.exists(trace_dir):
print(f"Warning: Directory {trace_dir} does not exist, skipping...")
continue
pattern = os.path.join(trace_dir, file_pattern)
files = glob.glob(pattern)
print(f"Found {len(files)} files in {trace_dir}")
for file_path in files:
filename = os.path.basename(file_path)
all_filenames.add(filename)
dir_file_mapping[filename].append(trace_dir)
print(f"Total unique filenames found: {len(all_filenames)}")
# Process each unique filename across all directories
for filename in tqdm(all_filenames, desc="Processing unique files"):
# Extract question ID from filename
question_id = None
if '_processed.jsonl' in filename:
try:
question_id = int(filename.replace('_processed.jsonl', ''))
except:
question_id = filename
else:
question_id = filename
# Get ground truth for this question
ground_truth = None
if ground_truth_map and question_id in ground_truth_map:
ground_truth = ground_truth_map[question_id]
# Collect traces from all directories for this filename
merged_traces = []
dirs_with_file = dir_file_mapping[filename]
for trace_dir in dirs_with_file:
file_path = os.path.join(trace_dir, filename)
if os.path.exists(file_path):
try:
traces = process_jsonl_file(file_path, ground_truth)
# Add directory info to each trace for identification
for i, trace in enumerate(traces):
trace['source_dir'] = trace_dir
trace['source_file'] = filename
# Create unique trace ID combining dir and original trace ID
original_trace_id = trace.get('trace_id', i)
trace['trace_id'] = f"{os.path.basename(trace_dir)}_{original_trace_id}"
merged_traces.extend(traces)
except Exception as e:
print(f"Error processing {file_path}: {e}")
continue
if merged_traces:
all_data[question_id] = merged_traces
print(f"Question {question_id}: Merged {len(merged_traces)} traces from {len(dirs_with_file)} directories")
return dict(all_data)
# ============================================
# SECTION 6: Analysis and Evaluation
# ============================================
def analyze_voting_performance(data, voting_sizes=[1, 2, 4, 8, 16, 32],
strategy='majority', weight_key='mean_confidence',
n_trials=1, seed=42, top_percent=None):
"""Analyze voting performance across different ensemble sizes"""
random.seed(seed)
np.random.seed(seed)
results = {}
for vote_size in voting_sizes:
accuracies = []
for trial in range(n_trials):
correct = 0
total = 0
for question_id, traces in data.items():
if len(traces) < vote_size:
continue
# Sample traces
sampled = random.sample(traces, vote_size)
# Apply voting strategy
if strategy == 'majority':
voted_answer, _ = majority_vote(sampled)
elif strategy == 'weighted':
voted_answer, _ = weighted_majority_vote(sampled, weight_key)
elif strategy == 'top_percent':
voted_answer, _ = top_percent_vote(sampled, weight_key, top_percent, 'majority')
elif strategy == 'top_percent_weighted':
voted_answer, _ = top_percent_vote(sampled, weight_key, top_percent, 'weighted')
else:
raise ValueError(f"Unknown strategy: {strategy}")
# Check correctness
if voted_answer is not None and sampled[0]['ground_truth'] is not None:
ground_truth = sampled[0]['ground_truth']
if math_equal(voted_answer, ground_truth):
correct += 1
total += 1
if total > 0:
accuracies.append(correct / total)
if accuracies:
results[vote_size] = {
'accuracy_mean': np.mean(accuracies),
'accuracy_std': np.std(accuracies)
}
return results
def analyze_top_percent_strategies(data, voting_sizes=[1, 2, 4, 8],
weight_keys=['mean_confidence', 'tail_2048_mean_conf'],
top_percents=[0.1, 0.2, 0.3, 0.5],
n_trials=1, seed=42):
"""
Comprehensive analysis of top percent filtering strategies
Args:
data: Processed trace data
voting_sizes: List of ensemble sizes to test
weight_keys: List of confidence metrics to use for filtering
top_percents: List of top percentages to test (e.g., [0.1, 0.2] for top 10%, 20%)
n_trials: Number of random trials per configuration
seed: Random seed for reproducibility
Returns:
Dictionary with results for each configuration
"""
results = {}
# Test each combination of parameters
for weight_key in weight_keys:
for top_percent in top_percents:
for vote_strategy in ['weighted']:
strategy_name = f"top_{int(top_percent*100)}%_{vote_strategy}_{weight_key}"
print(f"Testing {strategy_name}...")
strategy = 'top_percent' if vote_strategy == 'majority' else 'top_percent_weighted'
strategy_results = analyze_voting_performance(
data,
voting_sizes=voting_sizes,
strategy=strategy,
weight_key=weight_key,
n_trials=n_trials,
seed=seed,
top_percent=top_percent
)
results[strategy_name] = strategy_results
return results
def analyze_directory_distribution(data):
"""Analyze the distribution of traces across source directories"""
print("\n" + "="*60)
print("DIRECTORY DISTRIBUTION ANALYSIS")
print("="*60)
dir_stats = defaultdict(lambda: {'total_traces': 0, 'questions': set()})
for question_id, traces in data.items():
for trace in traces:
source_dir = trace.get('source_dir', 'unknown')
dir_stats[source_dir]['total_traces'] += 1
dir_stats[source_dir]['questions'].add(question_id)
print(f"{'Directory':<50} {'Traces':<10} {'Questions':<10}")
print("-" * 72)
for dir_name, stats in dir_stats.items():
short_name = os.path.basename(dir_name) if dir_name != 'unknown' else 'unknown'
print(f"{short_name:<50} {stats['total_traces']:<10} {len(stats['questions']):<10}")
return dir_stats
# ============================================
# MAIN EXECUTION CELL
# ============================================
def main_analysis_multi_dir(trace_dirs=None, file_pattern="*_processed.jsonl",
ground_truth_file=None, output_dir="./results"):
"""
Main analysis function for multiple directories
Args:
trace_dirs: List of directory paths to search for JSONL files
file_pattern: File pattern to match (e.g., "*_processed.jsonl")
ground_truth_file: Optional pickle file with ground truth answers
output_dir: Directory to save results
"""
if trace_dirs is None:
trace_dirs = TRACE_DIRS
print("="*60)
print("ENHANCED MULTI-DIRECTORY JSONL VOTING ANALYSIS")
print("="*60)
print(f"Processing {len(trace_dirs)} directories:")
for i, dir_path in enumerate(trace_dirs):
print(f" {i+1}. {dir_path}")
# Load ground truth if provided
ground_truth_map = None
if ground_truth_file and os.path.exists(ground_truth_file):
with open(ground_truth_file, 'rb') as f:
ground_truth_map = pickle.load(f)
print(f"Loaded ground truth from {ground_truth_file}")
# Process JSONL files from multiple directories
print(f"\nProcessing files with pattern: {file_pattern}")
data = process_multiple_dirs_jsonls(trace_dirs, file_pattern, ground_truth_map)
print(f"Processed {len(data)} questions")
total_traces = sum(len(traces) for traces in data.values())
print(f"Total traces: {total_traces}")
# Analyze directory distribution
dir_stats = analyze_directory_distribution(data)
# Calculate per-question statistics
print("\n" + "="*60)
print("PER-QUESTION STATISTICS")
print("="*60)
question_items = list(data.items())
for q_id, traces in question_items[:5]: # Show first 5 questions
correct = sum(1 for t in traces if t['is_correct'])
print(f"Question {q_id}: {correct}/{len(traces)} correct ({correct/len(traces):.2%})")
if traces:
mean_conf = np.mean([t.get('mean_confidence', 0) for t in traces if 'mean_confidence' in t])
print(f" Mean confidence: {mean_conf:.4f}")
# Show directory breakdown
dir_breakdown = defaultdict(int)
for trace in traces:
dir_name = os.path.basename(trace.get('source_dir', 'unknown'))
dir_breakdown[dir_name] += 1
print(f" Directory breakdown: {dict(dir_breakdown)}")
# Test baseline strategies
print("\n" + "="*60)
print("BASELINE VOTING STRATEGIES")
print("="*60)
baseline_strategies = [
('majority', 'Majority Vote', None),
('weighted', 'Weighted Vote (mean_conf)', 'mean_confidence'),
('weighted', 'Weighted Vote (tail_2048)', 'tail_2048_mean_conf'),
]
voting_sizes = [1, 2, 4, 8, 16, 32, 64, 128]
all_results = {}
for strategy, name, weight_key in baseline_strategies:
print(f"\n{name}:")
results = analyze_voting_performance(
data,
voting_sizes=voting_sizes,
strategy=strategy,
weight_key=weight_key,
n_trials=10
)
all_results[name] = results
print(f"{'Size':<6} {'Accuracy':<12} {'Std Dev':<10}")
print("-" * 30)
for size in voting_sizes:
if size in results:
acc = results[size]['accuracy_mean']
std = results[size]['accuracy_std']
print(f"{size:<6} {acc:<12.4f} {std:<10.4f}")
# Test top percent filtering strategies
print("\n" + "="*60)
print("TOP PERCENT FILTERING STRATEGIES")
print("="*60)
top_percent_results = analyze_top_percent_strategies(
data,
voting_sizes=voting_sizes, # Use smaller sizes for top percent
weight_keys=['mean_confidence', 'tail_2048_mean_conf', 'bottom_0.1_sliding_2048_mean_conf'],
top_percents=[0.1, 0.9],
n_trials=10
)
all_results.update(top_percent_results)
# Display top percent results
for strategy_name, strategy_results in top_percent_results.items():
print(f"\n{strategy_name}:")
print(f"{'Size':<6} {'Accuracy':<12} {'Std Dev':<10}")
print("-" * 30)
for size in voting_sizes:
if size in strategy_results:
acc = strategy_results[size]['accuracy_mean']
std = strategy_results[size]['accuracy_std']
print(f"{size:<6} {acc:<12.4f} {std:<10.4f}")
# Save results
os.makedirs(output_dir, exist_ok=True)
# Save processed data
data_path = os.path.join(output_dir, "processed_data_multi_dir.pkl")
with open(data_path, 'wb') as f:
pickle.dump(data, f)
print(f"\n✓ Saved processed data to {data_path}")
# Save voting results
results_path = os.path.join(output_dir, "voting_results_multi_dir.pkl")
with open(results_path, 'wb') as f:
pickle.dump(all_results, f)
print(f"✓ Saved voting results to {results_path}")
# Create comprehensive summary DataFrame
summary_data = []
for strategy_name, strategy_results in all_results.items():
for size, metrics in strategy_results.items():
summary_data.append({
'Strategy': strategy_name,
'Ensemble Size': size,
'Accuracy': metrics['accuracy_mean'],
'Std Dev': metrics['accuracy_std']
})
df_summary = pd.DataFrame(summary_data)
csv_path = os.path.join(output_dir, "voting_summary_multi_dir.csv")
df_summary.to_csv(csv_path, index=False)
print(f"✓ Saved summary CSV to {csv_path}")
# Find best performing strategies
print("\n" + "="*60)
print("BEST PERFORMING STRATEGIES")
print("="*60)
# Group by ensemble size and find best accuracy for each
for size in voting_sizes:
size_results = df_summary[df_summary['Ensemble Size'] == size]
if not size_results.empty:
best_row = size_results.loc[size_results['Accuracy'].idxmax()]
print(f"Size {size}: {best_row['Strategy']} (Accuracy: {best_row['Accuracy']:.4f})")
return data, all_results, df_summary, dir_stats
# ============================================
# EXAMPLE USAGE
# ============================================
if __name__ == "__main__":
# Example usage - modify these paths for your data
# Use the predefined TRACE_DIRS or specify your own
custom_dirs = None # Set to your list of directories if different from TRACE_DIRS
file_pattern = "*_processed.jsonl" # Adjust this pattern
ground_truth_file = None # Optional: "./ground_truth.pkl"
output_directory = "./voting_results_multi_dir"
# Run the analysis
data, results, summary_df, dir_stats = main_analysis_multi_dir(
trace_dirs=custom_dirs, # Will use TRACE_DIRS if None
file_pattern=file_pattern,
ground_truth_file=ground_truth_file,
output_dir=output_directory
)
# Display summary
print("\n" + "="*60)
print("FINAL SUMMARY")
print("="*60)
print(summary_df.to_string(index=False))
# Show some example merged data
print("\n" + "="*60)
print("EXAMPLE MERGED DATA")
print("="*60)
for q_id, traces in list(data.items())[:2]: # Show first 2 questions
print(f"\nQuestion {q_id} ({len(traces)} traces):")
for trace in traces[:3]: # Show first 3 traces
source_dir = os.path.basename(trace.get('source_dir', 'unknown'))
correct = trace.get('is_correct', False)
conf = trace.get('mean_confidence', 0)
answer = trace.get('extracted_answer', 'None')
print(f" {source_dir}: {answer} (correct: {correct}, conf: {conf:.4f})")
Example 3: Online Generation
python
import openai
import json
from tqdm import tqdm
import time
import os
import pickle
from datetime import datetime
import numpy as np
from collections import Counter
from dynasor.core.evaluator import math_equal
# ===========================
# Configuration
# ===========================
MODEL_PATH = "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B"
MAX_TOKENS = 64000
RID = 0
QID = 0
PORT = 8000
DATASET_FILE = "aime25.jsonl"
# Online algorithm parameters
WARMUP_TRACES = 16
TOTAL_BUDGET = 32
CONFIDENCE_PERCENTILE = 90
WINDOW_SIZE = 2048
# ===========================
# Answer Extraction Functions
# ===========================
def extract_answer(text):
"""Extract boxed answer from text"""
if "boxed" in text:
ans = text.split("boxed")[-1]
if len(ans) == 0:
return ""
elif ans[0] == "{":
stack = 1
a = ""
for c in ans[1:]:
if c == "{":
stack += 1
a += c
elif c == "}":
stack -= 1
if stack == 0:
break
a += c
else:
a += c
else:
a = ans.split("$")[0].strip()
return a.strip()
return None
# NOTE: assumed imports for the parsers used below (sympy and the latex2sympy2 package); adjust to your setup
from sympy.parsing.latex import parse_latex
from sympy.parsing.sympy_parser import parse_expr
from latex2sympy2 import latex2sympy

def parse_func(s):
    """Try several parsers in turn; fall back to the raw string if none succeeds."""
    for f in [parse_latex, parse_expr, latex2sympy]:
        try:
            return f(s.replace("\\\\", "\\"))
        except Exception:
            try:
                return f(s)
            except Exception:
                pass
    return s
# ===========================
# Data Loading
# ===========================
def load_aime25_jsonl(file_path=DATASET_FILE):
"""Load data from aime25.jsonl file"""
data = []
with open(file_path, 'r', encoding='utf-8') as file:
for line in file:
data.append(json.loads(line.strip()))
return data
# ===========================
# Confidence Calculation
# ===========================
def compute_confidence(logprobs):
"""Compute confidence score from logprobs."""
confs = []
for lp in logprobs:
confs.append(round(-sum([l.logprob for l in lp]) / len(lp), 3))
return confs
def compute_least_grouped(confs, group_size=WINDOW_SIZE):
"""Compute sliding window mean confidence with specified group size."""
if len(confs) < group_size:
return [sum(confs) / len(confs)] if confs else [0]
sliding_means = []
for i in range(len(confs) - group_size + 1):
window = confs[i:i + group_size]
sliding_means.append(round(sum(window) / len(window), 3))
return sliding_means
# ===========================
# Voting Functions
# ===========================
def weighted_majority_vote(answers, weights):
"""Perform weighted majority voting"""
if not answers:
return None
answer_weights = {}
for answer, weight in zip(answers, weights):
if answer is not None:
answer_str = str(answer)
answer_weights[answer_str] = answer_weights.get(answer_str, 0.0) + float(weight)
if not answer_weights:
return None
voted_answer = max(answer_weights.keys(), key=lambda x: answer_weights[x])
return voted_answer
# ===========================
# Trace Processing and Accuracy Calculation
# ===========================
def process_trace(choice, trace_id, ground_truth):
"""Process a single trace and calculate accuracy"""
# Extract basic info
text = choice.message.content
tokens = [t.token for t in choice.logprobs.content]
# Calculate confidence
confs = compute_confidence([t.top_logprobs for t in choice.logprobs.content])
sliding_window = compute_least_grouped(confs, group_size=WINDOW_SIZE)
# Extract and parse answer
extracted_answer = extract_answer(text)
parsed_answer = parse_func(extracted_answer) if extracted_answer else None
# Calculate correctness
is_correct = False
if extracted_answer is not None and ground_truth is not None:
is_correct = math_equal(extracted_answer, ground_truth)
trace_data = {
"trace_id": trace_id,
"stop_reason": choice.stop_reason,
"finish_reason": choice.finish_reason,
"text": text,
"tokens": tokens,
"token_count": len(tokens),
"confs": confs,
"group_confs": sliding_window,
"min_conf": min(sliding_window) if sliding_window else 0,
"extracted_answer": extracted_answer,
"parsed_answer": parsed_answer,
"is_correct": is_correct,
}
return trace_data
def calculate_statistics(traces, phase_name=""):
"""Calculate statistics for a list of traces"""
if not traces:
return {}
total_traces = len(traces)
correct_traces = sum(1 for t in traces if t['is_correct'])
total_tokens = sum(t['token_count'] for t in traces)
# Confidence statistics
min_confs = [t['min_conf'] for t in traces]
stats = {
f"{phase_name}_traces": total_traces,
f"{phase_name}_correct": correct_traces,
f"{phase_name}_accuracy": correct_traces / total_traces if total_traces > 0 else 0,
f"{phase_name}_total_tokens": total_tokens,
f"{phase_name}_avg_tokens_per_trace": total_tokens / total_traces if total_traces > 0 else 0,
f"{phase_name}_min_conf_mean": np.mean(min_confs) if min_confs else 0,
f"{phase_name}_min_conf_std": np.std(min_confs) if min_confs else 0,
}
return stats
# ===========================
# Problem Processing Function (like problemx in reference)
# ===========================
def process_problem_voting(test_json, ground_truth):
"""
Process a problem result JSON and perform voting
Similar to problemx function in reference code
"""
answers = []
bar = test_json['conf_bar']
weights = []
tokens = 0
print(f"Warmup traces: {len(test_json['warmup_traces'])}, Final traces: {len(test_json['final_traces'])}")
# Process warmup traces
for i in range(len(test_json['warmup_traces'])):
answer = extract_answer(test_json['warmup_traces'][i]['text'])
minx = min(test_json['warmup_traces'][i]['group_confs'])
tokens += len(test_json['warmup_traces'][i]['tokens'])
if minx < bar:
continue
if answer is not None:
answers.append(answer)
weights.append(1) # Use weight 1 for consistency with reference
# Process final traces
for i in range(len(test_json['final_traces'])):
tokens += len(test_json['final_traces'][i]['tokens'])
# Skip traces stopped by gconf
if test_json['final_traces'][i]['stop_reason'] is not None and 'gconf' in test_json['final_traces'][i]['stop_reason']:
continue
answer = extract_answer(test_json['final_traces'][i]['text'])
minx = min(test_json['final_traces'][i]['group_confs'])
if answer is not None:
answers.append(answer)
weights.append(1)
# Perform voting
voted = weighted_majority_vote(answers, weights)
is_correct = str(voted) == str(ground_truth) if voted is not None else False
print(f'Bar: {bar:.4f}, Voted: {voted}, Ground truth: {ground_truth}, Correct: {is_correct}, Voting answers: {len(answers)}')
return is_correct, len(answers), tokens
# ===========================
# Main Function
# ===========================
def main():
print("="*60)
print("ONLINE ALGORITHM WITH ACCURACY CALCULATION")
print("="*60)
print(f"Model: {MODEL_PATH}")
print(f"Question ID: {QID}")
print(f"Run ID: {RID}")
print(f"Warmup traces: {WARMUP_TRACES}")
print(f"Total budget: {TOTAL_BUDGET}")
print(f"Confidence percentile: {CONFIDENCE_PERCENTILE}")
# Load the data
data = load_aime25_jsonl()
print(f"Loaded {len(data)} items from {DATASET_FILE}")
# Initialize client
client = openai.OpenAI(
api_key="None",
base_url=f"http://localhost:{PORT}/v1",
timeout=None
)
# Get question and ground truth
prompt = data[QID]['question']
ground_truth = str(data[QID].get('answer', '')).strip()
messages = [
{"role": "system", "content": "该助手为DeepSeek-R1,由深度求索公司创造。\n今天是2025年5月28日,星期一。\n"},
{"role": "user", "content": prompt}
]
print(f"\nQuestion: {prompt}")
print(f"Ground Truth: {ground_truth}")
# ===========================
# WARMUP PHASE
# ===========================
print(f"\n{'-'*40}")
print("WARMUP PHASE")
print(f"{'-'*40}")
t0 = time.time()
responses = client.chat.completions.create(
model=MODEL_PATH,
messages=messages,
max_tokens=MAX_TOKENS,
temperature=0.6,
top_p=0.95,
logprobs=True,
top_logprobs=20,
n=WARMUP_TRACES,
extra_body={"top_k": 0},
)
t1 = time.time()
# Process warmup traces
warmup_traces = []
min_confs = []
for j in range(WARMUP_TRACES):
trace_data = process_trace(responses.choices[j], j, ground_truth)
warmup_traces.append(trace_data)
min_confs.append(trace_data["min_conf"])
# Calculate confidence bar
conf_bar = float(np.percentile(min_confs, CONFIDENCE_PERCENTILE))
# Calculate warmup statistics
warmup_stats = calculate_statistics(warmup_traces, "warmup")
print(f"Warmup time: {t1 - t0:.2f}s")
print(f"Confidence bar (P{CONFIDENCE_PERCENTILE}): {conf_bar:.4f}")
print(f"Warmup accuracy: {warmup_stats['warmup_accuracy']:.4f} ({warmup_stats['warmup_correct']}/{warmup_stats['warmup_traces']})")
print(f"Warmup total tokens: {warmup_stats['warmup_total_tokens']}")
print(f"Warmup avg tokens per trace: {warmup_stats['warmup_avg_tokens_per_trace']:.1f}")
# Show some example results
print(f"\nFirst 3 warmup traces:")
for i, trace in enumerate(warmup_traces[:3]):
print(f" Trace {i}: {trace['extracted_answer']} (correct: {trace['is_correct']}, conf: {trace['min_conf']:.4f}, tokens: {trace['token_count']})")
# ===========================
# FINAL PHASE
# ===========================
print(f"\n{'-'*40}")
print("FINAL PHASE (with early stopping)")
print(f"{'-'*40}")
real_gen = TOTAL_BUDGET - WARMUP_TRACES
t3 = time.time()
responses = client.chat.completions.create(
model=MODEL_PATH,
messages=messages,
max_tokens=MAX_TOKENS,
temperature=0.6,
top_p=0.95,
logprobs=True,
top_logprobs=20,
n=real_gen,
extra_body={
"top_k": 0,
"vllm_xargs": {
'enable_conf': True,
'window_size': WINDOW_SIZE,
'threshold': conf_bar
}
}
)
t4 = time.time()
# Process final traces
final_traces = []
for j in range(len(responses.choices)):
trace_data = process_trace(responses.choices[j], WARMUP_TRACES + j, ground_truth)
final_traces.append(trace_data)
# Calculate final statistics
final_stats = calculate_statistics(final_traces, "final")
print(f"Final time: {t4 - t3:.2f}s")
print(f"Final traces generated: {len(final_traces)} (requested: {real_gen})")
print(f"Final accuracy: {final_stats['final_accuracy']:.4f} ({final_stats['final_correct']}/{final_stats['final_traces']})")
print(f"Final total tokens: {final_stats['final_total_tokens']}")
print(f"Final avg tokens per trace: {final_stats['final_avg_tokens_per_trace']:.1f}")
# Show some example results
print(f"\nFirst 3 final traces:")
for i, trace in enumerate(final_traces[:3]):
print(f" Trace {i}: {trace['extracted_answer']} (correct: {trace['is_correct']}, conf: {trace['min_conf']:.4f}, tokens: {trace['token_count']})")
# ===========================
# OVERALL STATISTICS
# ===========================
print(f"\n{'-'*40}")
print("OVERALL STATISTICS")
print(f"{'-'*40}")
all_traces = warmup_traces + final_traces
overall_stats = calculate_statistics(all_traces, "overall")
total_time = t4 - t0
warmup_time = t1 - t0
final_time = t4 - t3
print(f"Total time: {total_time:.2f}s (warmup: {warmup_time:.2f}s, final: {final_time:.2f}s)")
print(f"Overall traces: {len(all_traces)}")
print(f"Overall accuracy: {overall_stats['overall_accuracy']:.4f} ({overall_stats['overall_correct']}/{overall_stats['overall_traces']})")
print(f"Overall total tokens: {overall_stats['overall_total_tokens']}")
print(f"Overall avg tokens per trace: {overall_stats['overall_avg_tokens_per_trace']:.1f}")
# Efficiency metrics
tokens_per_second = overall_stats['overall_total_tokens'] / total_time
traces_per_second = len(all_traces) / total_time
print(f"Tokens per second: {tokens_per_second:.1f}")
print(f"Traces per second: {traces_per_second:.2f}")
# Early stopping analysis
if final_traces:
above_threshold = sum(1 for t in final_traces if t['min_conf'] >= conf_bar)
print(f"Final traces above threshold: {above_threshold}/{len(final_traces)} ({above_threshold/len(final_traces):.2%})")
# ===========================
# SAVE RESULTS
# ===========================
results = {
"question_id": QID,
"run_id": RID,
"question": prompt,
"ground_truth": ground_truth,
"conf_bar": conf_bar,
"warmup_traces": warmup_traces,
"final_traces": final_traces,
"statistics": {
**warmup_stats,
**final_stats,
**overall_stats,
"total_time": total_time,
"warmup_time": warmup_time,
"final_time": final_time,
"tokens_per_second": tokens_per_second,
"traces_per_second": traces_per_second,
},
"config": {
"model_path": MODEL_PATH,
"warmup_traces": WARMUP_TRACES,
"total_budget": TOTAL_BUDGET,
"confidence_percentile": CONFIDENCE_PERCENTILE,
"window_size": WINDOW_SIZE,
},
"timestamp": datetime.now().isoformat(),
}
# Save results to a pickle file
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
os.makedirs("outputs", exist_ok=True)
output_filename = f"outputs/q{QID}_r{RID}_{timestamp}.pkl"
with open(output_filename, "wb") as f:
    pickle.dump(results, f)
print(f"\n✅ Results saved to {output_filename}")
# ===========================
# COMPARISON WITH BASELINE
# ===========================
print(f"\n{'-'*40}")
print("COMPARISON ANALYSIS")
print(f"{'-'*40}")
# Compare warmup vs final phase
if warmup_stats['warmup_traces'] > 0 and final_stats['final_traces'] > 0:
acc_improvement = final_stats['final_accuracy'] - warmup_stats['warmup_accuracy']
print(f"Accuracy change (final - warmup): {acc_improvement:+.4f}")
token_efficiency = final_stats['final_avg_tokens_per_trace'] / warmup_stats['warmup_avg_tokens_per_trace']
print(f"Token efficiency (final/warmup): {token_efficiency:.2f}x")
# Early stopping effectiveness
total_requested = WARMUP_TRACES + (TOTAL_BUDGET - WARMUP_TRACES)
total_generated = len(all_traces)
saving_ratio = (total_requested - total_generated) / total_requested
print(f"Traces requested: {total_requested}")
print(f"Traces generated: {total_generated}")
print(f"Early stopping saving: {saving_ratio:.2%}")
# ===========================
# VOTING FOR FINAL ANSWER
# ===========================
print(f"\n{'-'*40}")
print("VOTING FOR FINAL ANSWER")
print(f"{'-'*40}")
# Collect answers above threshold for voting
voting_answers = []
voting_weights = []
total_voting_tokens = 0
# Add warmup traces above threshold (use confidence as weight)
for trace in warmup_traces:
minx = trace['min_conf']
total_voting_tokens += trace['token_count']
if minx >= conf_bar:
answer = trace['extracted_answer']
if answer is not None:
voting_answers.append(answer)
voting_weights.append(minx)
# Add final traces above threshold (use weight 1)
for trace in final_traces:
total_voting_tokens += trace['token_count']
# Skip traces that were stopped by gconf (early stopping)
if trace['stop_reason'] is not None and 'gconf' in trace['stop_reason']:
continue
minx = trace['min_conf']
# Note: final traces might not need threshold check since they were already filtered
# But keeping it consistent with the reference code
if minx >= conf_bar:
answer = trace['extracted_answer']
if answer is not None:
voting_answers.append(answer)
voting_weights.append(1)
print(f"Traces used for voting: {len(voting_answers)}")
print(f"Total tokens in voting traces: {total_voting_tokens}")
# Perform weighted majority vote
voted_answer = weighted_majority_vote(voting_answers, voting_weights)
# Check if voted answer is correct
is_voted_correct = False
if voted_answer is not None and ground_truth:
is_voted_correct = str(voted_answer) == str(ground_truth)
# Also try math_equal for more robust comparison
try:
is_voted_correct_math = math_equal(voted_answer, ground_truth)
if is_voted_correct != is_voted_correct_math:
print(f"Warning: String comparison ({is_voted_correct}) != math comparison ({is_voted_correct_math})")
is_voted_correct = is_voted_correct_math # Use math_equal as more reliable
except:
pass # Fallback to string comparison
print(f"Voted answer: {voted_answer}")
print(f"Ground truth: {ground_truth}")
print(f"Voted answer correct: {is_voted_correct}")
# Show voting breakdown
if voting_answers:
answer_counts = Counter(voting_answers)
print(f"\nVoting breakdown:")
for answer, count in answer_counts.most_common():
print(f" {answer}: {count} votes")
# ===========================
# UPDATE RESULTS WITH VOTING
# ===========================
voting_results = {
"voting_answers": voting_answers,
"voting_weights": voting_weights,
"voted_answer": voted_answer,
"is_voted_correct": is_voted_correct,
"voting_traces_count": len(voting_answers),
"voting_total_tokens": total_voting_tokens,
}
results["voting"] = voting_results
results["statistics"]["voting_traces_count"] = len(voting_answers)
results["statistics"]["voting_total_tokens"] = total_voting_tokens
results["statistics"]["is_voted_correct"] = is_voted_correct
# Update the saved file with voting results (optional):
# with open(output_filename, "wb") as f:
#     pickle.dump(results, f)
# print(f"\n✅ Results updated with voting information: {output_filename}")
# ===========================
# FINAL SUMMARY
# ===========================
print(f"\n{'='*60}")
print("FINAL SUMMARY")
print(f"{'='*60}")
print(f"Question ID: {QID}")
print(f"Confidence bar: {conf_bar:.4f}")
print(f"Total traces generated: {len(all_traces)}")
print(f"Traces used for voting: {len(voting_answers)}")
print(f"Total tokens: {overall_stats['overall_total_tokens']}")
print(f"Voting answer: {voted_answer}")
print(f"Ground truth: {ground_truth}")
print(f"Final result: {'✅ CORRECT' if is_voted_correct else '❌ INCORRECT'}")
return results
if __name__ == "__main__":
results = main()
⚠️ Important Notes
- Make sure vLLM server is running before executing the examples
- Adjust the configuration parameters based on your specific requirements
- The examples use the Dynasor library for evaluation - ensure it's properly installed
- For production use, consider adding proper error handling and logging
Summary
This guide has provided you with:
- ✅ Complete instructions for modifying vLLM to implement DeepConf
- ✅ Alternative option to download the pre-implemented PR
- ✅ Three comprehensive examples for different use cases:
- Offline Generation with parallel processing
- Offline Voting with multiple strategies
- Online Generation with confidence-based early stopping
For more information and updates, visit the official PR on GitHub.