This document consolidates everything you need to use DeepConf on vLLM:
- How to modify vLLM to implement DeepConf - A minimal patch guide showing what to change, where, and why.
- Or skip the modifications - Directly download our PR with pre-implemented changes.
- Example code to run DeepConf - Complete working examples demonstrating the feature.
Option 1: Directly download our PR
- Related PR: https://github.com/vllm-project/vllm/pull/23201
Option 2: Modify vLLM
This README documents a minimal set of edits to integrate DeepConf (confidence-based early stopping) into vLLM, how to enable it via the OpenAI-compatible API, and pointers to the example notebook.
Tested Environment
- vLLM: commit `31f09c615f4f067dba765ce5fe7d00d880212a6d`
- Python: 3.12.0
- CUDA: 12.8
High-Level Changes
We modify only two places in vLLM:
- Extend `LogprobsProcessor` to maintain a sliding-window confidence and expose `check_conf_stop()`.
- In `output_processor.py`, insert a single early-stop check before constructing `RequestOutput`.
Enabling via OpenAI-Compatible API
The feature is toggled per request via the OpenAI-compatible `chat.completions` endpoint. The arguments are passed through `extra_body["vllm_xargs"]` and forwarded by vLLM to `SamplingParams.extra_args`.
```python
# Code: Enable confidence-based early stopping via OpenAI-compatible API
responses = client.chat.completions.create(
    model=args.model_path,
    messages=messages,
    max_tokens=args.max_tokens,
    temperature=0.6,
    top_p=0.95,
    logprobs=True,
    top_logprobs=20,  # request candidate logprobs (>=2)
    n=real_gen,
    extra_body={
        "top_k": 0,
        "vllm_xargs": {
            "enable_conf": True,
            "window_size": 2048,
            "threshold": conf_threshold
        }
    }
)
```
- The early-stop logic is inactive unless `logprobs=True` and `top_logprobs>=2`.
- `window_size` is the confidence window length; `threshold` is the cutoff used by our method.
- `top_k=0` (optional) disables top-k truncation.
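When the check fires for a trace, vLLM finishes that choice with an annotated stop reason (see Step 6 below). As a minimal sketch, assuming `responses` is the return value of the call above, you can separate early-stopped traces from complete ones by inspecting each choice's `stop_reason` (a vLLM-specific field on the returned choices):

```python
# Minimal sketch (not part of the patch): split returned traces by whether
# the confidence stop fired. Early-stopped choices carry a "<gconf<...>"
# marker in `stop_reason`; the others finished normally (EOS / max_tokens).
kept, early_stopped = [], []
for choice in responses.choices:
    stop_reason = getattr(choice, "stop_reason", None)
    if stop_reason is not None and "gconf" in str(stop_reason):
        early_stopped.append(choice)
    else:
        kept.append(choice)

print(f"kept {len(kept)} traces, early-stopped {len(early_stopped)}")
```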
Exact Edits (Copy-Paste Guidance)
No patch tools are required; copy the snippets below into the indicated files. We recommend pinning to the commit above to avoid API drift.
File: `vllm/v1/engine/logprobs.py`
Step 1: Import (near the top):
```python
from collections import deque
from typing import Optional, List
```
Step 2: Extend the dataclass (`class LogprobsProcessor`):
```python
# --- fields for confidence-based early stopping ---
conf_grouped: float
conf_list: Optional[List[float]]
conf_group_list: Optional[deque]
conf_group_size: int
conf_threshold: Optional[float]
```
Step 3: Initialize from the request (inside `from_new_request(...)`, right before `return cls(...)`):
```python
if hasattr(request.sampling_params, "extra_args") \
        and request.sampling_params.extra_args is not None \
        and request.sampling_params.extra_args.get("enable_conf", False):
    conf_group_size = request.sampling_params.extra_args.get("window_size", 2048)
    conf_threshold = request.sampling_params.extra_args.get("threshold", 17)
    conf_grouped = 0.0
    conf_group_list = deque(maxlen=conf_group_size)
    conf_list = []
else:
    conf_group_size = -1
    conf_threshold = None
    conf_grouped = 0.0
    conf_group_list = None
    conf_list = None
```
Then include the fields below in the `return cls(...)` call:
```python
conf_group_size=conf_group_size,
conf_grouped=conf_grouped,
conf_list=conf_list,
conf_threshold=conf_threshold,
conf_group_list=conf_group_list,
```
Step 4: Stop-check helper (add this method inside the class):
```python
def check_conf_stop(self) -> bool:
    """Return True if the confidence window triggers early stopping."""
    if self.conf_group_list is None or len(self.conf_group_list) == 0:
        return False
    # Require a full window; trigger when the moving average is below threshold.
    return (len(self.conf_group_list) >= self.conf_group_size
            and self.conf_grouped / len(self.conf_group_list) < self.conf_threshold)
```
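To make the condition concrete, here is a toy illustration that mirrors `check_conf_stop()` with a window of 3 and a threshold of 17 (values chosen for the example only, not part of the patch):

```python
# Toy illustration of the stop condition: the check fires only once the
# window is full AND its mean confidence drops below the threshold.
from collections import deque

window = deque([20.0, 15.0, 14.0], maxlen=3)  # per-token confidences
grouped = sum(window)                          # running sum = 49.0
should_stop = len(window) >= 3 and grouped / len(window) < 17
print(should_stop)                             # True: mean 16.33 < 17
```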
Step 5: Update confidence during sampling (at the end of `_update_sample_logprobs(...)`, after appending the logprob dict):
```python
if self.conf_list is not None:
    # logprobs[0] is the sampled token; use the remaining candidates
    if len(logprobs) > 1:
        new_conf = -sum(logprobs[1:]) / len(logprobs[1:])
    else:
        new_conf = 0.0
    self.conf_list.append(new_conf)
    if len(self.conf_group_list) < self.conf_group_size:
        self.conf_group_list.append(new_conf)
        self.conf_grouped += new_conf
    else:
        self.conf_grouped -= self.conf_group_list.popleft()
        self.conf_group_list.append(new_conf)
        self.conf_grouped += new_conf
```
File: `vllm/v1/engine/output_processor.py`
Step 6: Invoke the stop-check in the decode loop.
Immediately after:
```python
req_state.logprobs_processor.update_from_output(engine_core_output)
```
insert:
```python
# Confidence-based early stopping (ours)
if req_state.logprobs_processor.check_conf_stop():
    finish_reason = FinishReason.STOP
    stop_reason = f"<gconf<{req_state.logprobs_processor.conf_threshold}>"
```
(Leave the subsequent logic that builds `RequestOutput` unchanged.)
Additional Notes
- The feature is inactive unless `enable_conf=True` and `logprobs>0` (we use `top_logprobs=20`).
- Confidence is the moving average of the negative mean candidate logprobs over a fixed window (`window_size`); the sketch after this list reproduces it on the client side.
- When triggered, we set `FinishReason.STOP` and annotate `stop_reason` with `<gconf<THR>>` for traceability.
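For reference, the same per-token confidence can be reproduced on the client side from the `top_logprobs` returned by the API, as the example scripts below do. A minimal sketch (the helper names are ours; `choice` is one entry of `responses.choices` from a request like the one above):

```python
from collections import deque

def per_token_confidence(choice):
    """Negative mean of the candidate logprobs at each decoding step."""
    confs = []
    for step in choice.logprobs.content:
        cands = step.top_logprobs
        confs.append(-sum(c.logprob for c in cands) / len(cands) if cands else 0.0)
    return confs

def sliding_window_means(confs, window_size=2048):
    """Moving average of the per-token confidence over a fixed window."""
    window, total, means = deque(maxlen=window_size), 0.0, []
    for c in confs:
        if len(window) == window_size:
            total -= window[0]  # value about to be evicted from the window
        window.append(c)
        total += c
        means.append(total / len(window))
    return means
```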
Alternatively, go to the PR linked above and download the vLLM version that already includes these changes.
Running Examples
Install vLLM
Download vLLM and apply the modifications above, or check out the version from the PR https://github.com/vllm-project/vllm/pull/23201 that already contains them, then build it as follows.
```bash
VLLM_USE_PRECOMPILED=1 uv pip install --editable .
```
Install dependencies for the example code
```bash
git clone https://github.com/hao-ai-lab/Dynasor.git
cd Dynasor && pip install . && cd -
```
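A quick sanity check that the install worked (`math_equal` is the answer-equivalence helper the voting examples below rely on; its behavior on harder expressions depends on the library):

```python
# Verify that Dynasor is importable before running the examples.
from dynasor.core.evaluator import math_equal

print(math_equal("42", "42"))  # expect True for identical answers
```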
Example 1: Offline Generation
import openai import json from tqdm import tqdm import time import os import requests from datetime import datetime from transformers import AutoTokenizer import concurrent.futures import threading from functools import partial # =========================== # Model Configurations # =========================== MODEL_CONFIGS = { "Qwen/Qwen3-8B": { "temperature": 0.6, "top_p": 0.95, "top_k": 20, "max_tokens": 32000, "template": "qwen3" }, "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B": { "temperature": 0.6, "top_p": 0.95, "top_k": 0, "max_tokens": 64000, "template": "dpsk_qwen_0528" }, "openai/gpt-oss-20b": { "temperature": 1.0, "top_p": 1.0, "top_k": 40, "max_tokens": 14000, "template": "gpt" }, # Add more model configurations as needed } # =========================== # Main Configuration # =========================== # Select your model MODEL_NAME = "Qwen/Qwen3-8B" # Change this to your desired model SAMPLES_PER_QUESTION = 4 # Number of traces to generate per question DATASET_FILE = "aime25.jsonl" # Input dataset file REASONING_EFFORT = "high" # For GPT models: low, medium, high # Parallel processing configuration MAX_WORKERS = 8 # Maximum number of concurrent workers (adjust based on your server capacity) MAX_WORKERS_PER_QUESTION = 4 # Maximum workers for traces within a single question # Get model-specific config model_config = MODEL_CONFIGS.get(MODEL_NAME) # General Configuration CONFIG = { "model_path": MODEL_NAME, "server_port": 8000, "temperature": model_config["temperature"], "top_p": model_config["top_p"], "top_k": model_config["top_k"], "max_tokens": model_config["max_tokens"], "template": model_config["template"], "reasoning_effort": REASONING_EFFORT, # Dataset and sampling configuration "dataset": DATASET_FILE, # Input dataset file "max_samples_per_question": SAMPLES_PER_QUESTION, # Number of traces per question "output_dir": f"output_{MODEL_NAME.replace('/', '_')}_{datetime.now().strftime('%Y%m%d_%H%M%S')}", # Parallel processing configuration "max_workers": MAX_WORKERS, "max_workers_per_question": MAX_WORKERS_PER_QUESTION, } # Thread-safe file writing lock file_lock = threading.Lock() # =========================== # Initialize OpenAI Client # =========================== # Note: Make sure vLLM server is already running on the specified port # Example command to start vLLM server: # vllm serve MODEL_NAME --port 8000 -tp 1 --gpu-memory-utilization 0.7 --enable-prefix-caching print(f"Connecting to vLLM server...") print(f"Model: {CONFIG['model_path']}") print(f"Server URL: http://localhost:{CONFIG['server_port']}/v1") print(f"Max concurrent workers: {CONFIG['max_workers']}") print(f"Max workers per question: {CONFIG['max_workers_per_question']}") # Initialize OpenAI client client = openai.OpenAI( api_key="None", base_url=f"http://localhost:{CONFIG['server_port']}/v1", timeout=None ) # Initialize tokenizer for GPT models if CONFIG['template'] == 'gpt': tokenizer = AutoTokenizer.from_pretrained(CONFIG['model_path']) else: tokenizer = None # Test connection try: response = requests.get( f"http://localhost:{CONFIG['server_port']}/v1/models", headers={"Authorization": "Bearer None"}, ) if response.status_code == 200: print("✅ Successfully connected to vLLM server") else: print(f"⚠️ Server returned status code: {response.status_code}") except requests.exceptions.RequestException as e: print(f"❌ Failed to connect to vLLM server: {e}") print("Please ensure vLLM server is running on the specified port") # =========================== # Core Processing Functions # =========================== def 
get_gpt_token_probabilities(messages, max_tokens=50): """ Function to get token probabilities for GPT models using completions API """ response = client.completions.create( model=CONFIG['model_path'], prompt=messages, max_tokens=max_tokens, temperature=CONFIG['temperature'], top_p=CONFIG['top_p'], logprobs=20, extra_body={ "top_k": CONFIG['top_k'] }, ) # Extract generated text generated_text = response.choices[0].text # Extract token probabilities token_probs = [] mean_confs = [] tokens = [] log_probs = [] if response.choices[0].logprobs and response.choices[0].logprobs.tokens: for i, token_data in enumerate(response.choices[0].logprobs.tokens): step_probs = { "s": i, # step -> s "t": response.choices[0].logprobs.tokens[i], # generated_token -> t "lp": round(response.choices[0].logprobs.token_logprobs[i], 2), # logprob of generated token "a": [] # top_20_tokens -> a (alternatives) } # Add only top 5 alternatives to save space if response.choices[0].logprobs.top_logprobs: for tok, value in response.choices[0].logprobs.top_logprobs[i].items(): # Only top 5 step_probs["a"].append([ tok, round(value, 2) ]) # Use array instead of dict token_probs.append(step_probs) if step_probs['a']: mean_confs.append(round(-sum(p[1] for p in step_probs['a']) / len(step_probs['a']), 2)) else: mean_confs.append(0) tokens.append(response.choices[0].logprobs.tokens[i]) log_probs.append(round(response.choices[0].logprobs.token_logprobs[i], 2)) return { "text": generated_text, "probs": token_probs, # token_probabilities -> probs, "mean_confidences": mean_confs, # mean_confidences -> mean_confs "tokens": tokens, "log_probs": log_probs # log_probs -> log_probs } def get_token_probabilities(prompt, messages): """Get token probabilities from the vLLM server using chat completions API.""" response = client.chat.completions.create( model=CONFIG['model_path'], messages=messages, max_tokens=CONFIG['max_tokens'], temperature=CONFIG['temperature'], top_p=CONFIG['top_p'], logprobs=True, top_logprobs=20, extra_body={"top_k": CONFIG['top_k']}, ) generated_text = response.choices[0].message.content token_probs = [] mean_confs = [] tokens = [] log_probs = [] if response.choices[0].logprobs and response.choices[0].logprobs.content: for i, token_data in enumerate(response.choices[0].logprobs.content): step_probs = { "s": i, "t": token_data.token, "lp": round(token_data.logprob, 2), "a": [] } if token_data.top_logprobs: for logprob_data in token_data.top_logprobs[:5]: # Top 5 alternatives step_probs["a"].append([ logprob_data.token, round(logprob_data.logprob, 2) ]) token_probs.append(step_probs) if step_probs['a']: mean_confs.append(round(-sum(p[1] for p in step_probs['a']) / len(step_probs['a']), 2)) else: mean_confs.append(0) tokens.append(token_data.token) log_probs.append(round(token_data.logprob, 2)) return { "text": generated_text, "probs": token_probs, "mean_confidences": mean_confs, "tokens": tokens, "log_probs": log_probs } def prepare_messages(prompt, template): """Prepare messages based on template.""" if template == "dpsk_qwen_0528": return [ {"role": "system", "content": "该助手为DeepSeek-R1,由深度求索公司创造。\n今天是2025年5月28日,星期一。\n"}, {"role": "user", "content": prompt} ] elif template == 'qwen3': return [ {"role": "user", "content": prompt + "\nPlease reason step by step, and put your final answer within \\boxed{}."} ] elif template == 'gpt': # For GPT models, we'll prepare a simple string message first return prompt + "\nPlease reason step by step, and put your final answer within \\boxed{{}}." 
else: return [{"role": "user", "content": prompt}] def generate_single_trace(question_meta, trace_idx, output_dir): """Generate a single trace for a question. This function will be run in parallel.""" try: prompt = question_meta["prompt"] q_idx = question_meta["question_id"] messages = prepare_messages(prompt, CONFIG['template']) # Handle GPT models differently if CONFIG['template'] == 'gpt': # Apply chat template with reasoning effort for GPT models if tokenizer: formatted_messages = tokenizer.apply_chat_template( conversation=[ {"role": "user", "content": messages} ], add_generation_prompt=True, reasoning_effort=CONFIG['reasoning_effort'], tokenize=False, ) result = get_gpt_token_probabilities(messages=formatted_messages, max_tokens=CONFIG['max_tokens']) else: # Fallback if tokenizer is not available result = get_gpt_token_probabilities(messages=messages, max_tokens=CONFIG['max_tokens']) else: # Use chat completions for other models result = get_token_probabilities(prompt, messages) # Prepare trace data trace_data_processed = { "question_meta": question_meta, "trace_id": trace_idx, "response": result["text"], "tokens": result["tokens"], "mean_confidences": result["mean_confidences"], "log_probs": result["log_probs"], "messages": messages, } # Thread-safe file writing processed_file = os.path.join(output_dir, f"{q_idx}_processed.jsonl") with file_lock: with open(processed_file, "a", encoding="utf-8") as f: f.write(json.dumps(trace_data_processed, ensure_ascii=False) + "\n") return True, None except Exception as e: return False, f"Error in question {question_meta['question_id']}, trace {trace_idx}: {e}" def process_question_parallel(question, q_idx, output_dir): """Process a single question and generate multiple traces in parallel.""" prompt = question.get("problem", question.get("question", question.get("prompt", ""))) if not prompt: print(f"Warning: No prompt found in question {q_idx}") return 0 question_meta = { "question_id": q_idx, "original_question": question, "prompt": prompt, } # Generate traces in parallel completed_traces = 0 with concurrent.futures.ThreadPoolExecutor(max_workers=CONFIG['max_workers_per_question']) as executor: # Submit all trace generation tasks future_to_trace = { executor.submit(generate_single_trace, question_meta, trace_idx, output_dir): trace_idx for trace_idx in range(CONFIG['max_samples_per_question']) } # Collect results as they complete for future in concurrent.futures.as_completed(future_to_trace): trace_idx = future_to_trace[future] try: success, error_msg = future.result() if success: completed_traces += 1 else: print(f"Error: {error_msg}") except Exception as e: print(f"Exception in trace {trace_idx}: {e}") return completed_traces def process_single_question_wrapper(args): """Wrapper function for processing a single question (needed for parallel execution).""" question, q_idx, output_dir = args return q_idx, process_question_parallel(question, q_idx, output_dir) def process_dataset_parallel(dataset_file, output_dir): """Process entire dataset with parallel processing.""" os.makedirs(output_dir, exist_ok=True) print(f"Created output directory: {output_dir}") # Load dataset questions = [] try: with open(dataset_file, "r", encoding="utf-8") as f: for line in f: questions.append(json.loads(line.strip())) print(f"Loaded {len(questions)} questions from {dataset_file}") except FileNotFoundError: print(f"Error: {dataset_file} not found!") return None # Process questions in parallel all_results = [] total_traces = 0 # Prepare arguments for parallel 
processing question_args = [(question, q_idx, output_dir) for q_idx, question in enumerate(questions)] print(f"Processing {len(questions)} questions with up to {CONFIG['max_workers']} parallel workers...") with concurrent.futures.ThreadPoolExecutor(max_workers=CONFIG['max_workers']) as executor: # Submit all question processing tasks future_to_question = { executor.submit(process_single_question_wrapper, args): args[1] for args in question_args } # Use tqdm to track progress with tqdm(total=len(questions), desc="Processing questions") as pbar: for future in concurrent.futures.as_completed(future_to_question): q_idx = future_to_question[future] try: result_q_idx, traces_completed = future.result() total_traces += traces_completed all_results.append({ "question_id": result_q_idx, "total_traces": traces_completed, "file_path": os.path.join(output_dir, f"{result_q_idx}.jsonl") }) pbar.update(1) pbar.set_postfix({ 'completed_traces': total_traces, 'avg_traces': f"{total_traces / len(all_results):.1f}" if all_results else "0" }) except Exception as e: print(f"Exception processing question {q_idx}: {e}") pbar.update(1) # Save summary summary_file = os.path.join(output_dir, "summary.json") summary = { "model": CONFIG['model_path'], "model_config": model_config, "dataset_file": dataset_file, "total_questions": len(questions), "completed_questions": len(all_results), "total_traces": total_traces, "average_traces_per_question": total_traces / len(all_results) if all_results else 0, "output_directory": output_dir, "timestamp": datetime.now().isoformat(), "reasoning_effort": CONFIG.get('reasoning_effort', 'N/A'), "parallel_config": { "max_workers": CONFIG['max_workers'], "max_workers_per_question": CONFIG['max_workers_per_question'] } } with open(summary_file, "w", encoding="utf-8") as f: json.dump(summary, f, indent=2, ensure_ascii=False) print(f"\n✅ Completed! 
Generated {total_traces} total traces") print(f"📁 Results saved to: {output_dir}") print(f"📊 Summary: {summary_file}") print(f"📈 Average traces per question: {total_traces / len(all_results):.1f}") return output_dir def check_results(output_dir): """Check and display results from output directory.""" if not os.path.exists(output_dir): print(f"Directory {output_dir} not found!") return # Load summary summary_file = os.path.join(output_dir, "summary.json") if os.path.exists(summary_file): with open(summary_file, "r", encoding="utf-8") as f: summary = json.load(f) print(f"\nSummary:") print(f" Model: {summary['model']}") print(f" Total questions: {summary['total_questions']}") print(f" Total traces: {summary['total_traces']}") print(f" Average traces per question: {summary['average_traces_per_question']:.1f}") print(f" Reasoning effort: {summary.get('reasoning_effort', 'N/A')}") if 'parallel_config' in summary: print(f" Max workers: {summary['parallel_config']['max_workers']}") print(f" Max workers per question: {summary['parallel_config']['max_workers_per_question']}") # Check individual files question_files = [f for f in os.listdir(output_dir) if f.endswith('.jsonl')] question_files.sort(key=lambda x: int(x.split('.')[0].split('_')[0]) if x.split('.')[0].split('_')[0].isdigit() else float('inf')) print(f"\nFound {len(question_files)} question files") # Show sample results for filename in question_files[:3]: if '_processed' in filename: continue filepath = os.path.join(output_dir, filename) if os.path.exists(filepath): with open(filepath, "r", encoding="utf-8") as f: lines = f.readlines() if lines: first_trace = json.loads(lines[0].strip()) print(f"\n{filename}:") print(f" Traces: {len(lines)}") print(f" First response preview: {first_trace['response'][:150]}...") # =========================== # Performance Monitoring # =========================== def monitor_performance(): """Monitor and suggest optimal worker configuration.""" import psutil cpu_count = psutil.cpu_count() memory_gb = psutil.virtual_memory().total / (1024**3) print(f"\n🔧 System Information:") print(f" CPU cores: {cpu_count}") print(f" Total memory: {memory_gb:.1f} GB") # Suggest optimal configuration suggested_workers = min(cpu_count * 2, 16) # Generally good for I/O bound tasks suggested_workers_per_q = min(4, suggested_workers // 2) print(f"\n💡 Suggested Configuration:") print(f" MAX_WORKERS: {suggested_workers}") print(f" MAX_WORKERS_PER_QUESTION: {suggested_workers_per_q}") if CONFIG['max_workers'] > suggested_workers: print(f"⚠️ Current MAX_WORKERS ({CONFIG['max_workers']}) might be too high for your system") return suggested_workers, suggested_workers_per_q # =========================== # Main Execution # =========================== # Monitor system performance monitor_performance() print(f"\n🚀 Starting parallel processing:") print(f" Model: {CONFIG['model_path']}") print(f" Template: {CONFIG['template']}") print(f" Dataset: {CONFIG['dataset']}") print(f" Traces per question: {CONFIG['max_samples_per_question']}") print(f" Max workers: {CONFIG['max_workers']}") print(f" Max workers per question: {CONFIG['max_workers_per_question']}") print(f" Output directory: {CONFIG['output_dir']}") if CONFIG['template'] == 'gpt': print(f" Reasoning effort: {CONFIG['reasoning_effort']}") start_time = time.time() # Process the dataset with parallel processing output_dir = process_dataset_parallel( dataset_file=CONFIG['dataset'], output_dir=CONFIG['output_dir'] ) end_time = time.time() processing_time = end_time - start_time # Check 
results if output_dir: print("\n" + "="*50) print("Checking generated results...") check_results(output_dir) print(f"\n⚡ Processing completed in {processing_time:.2f} seconds!") print(f"🚀 Speed improvement with parallel processing!") print("Note: vLLM server is still running. Stop it manually if needed.")
Example 2: Offline Voting
import os import json import pickle import numpy as np import pandas as pd from collections import Counter, defaultdict from tqdm import tqdm import glob import random from typing import Dict, List, Tuple, Optional, Any from dynasor.core.evaluator import math_equal TRACE_DIRS = [YOURDIR1, YOURDIR2, ...] # ============================================ # SECTION 1: Math Evaluation Functions # ============================================ def parse_func(s): for f in [parse_latex, parse_expr, latex2sympy]: try: return f(s.replace("\\\\", "\\")) except: try: return f(s) except: pass return s def quick_parse(text): """Quick parse to remove LaTeX text formatting""" if '\\text{' in text and '}' in text: while '\\text{' in text: start = text.find('\\text{') if start == -1: break end = text.find('}', start) if end == -1: break content = text[start + 6:end] text = text[:start] + content + text[end + 1:] return text # ============================================ # SECTION 2: Answer Extraction # ============================================ def extract_answer(text): """Extract boxed answer from text""" if "boxed" in text: ans = text.split("boxed")[-1] if len(ans) == 0: return "" elif ans[0] == "{": stack = 1 a = "" for c in ans[1:]: if c == "{": stack += 1 a += c elif c == "}": stack -= 1 if stack == 0: break a += c else: a += c else: a = ans.split("$")[0].strip() return a.strip() return None # ============================================ # SECTION 3: Confidence Metrics Calculation # ============================================ def calculate_confidence_stats(conf_list, tokens): """Calculate various confidence statistics""" if not conf_list: return {} assert len(conf_list) == len(tokens), "Confidence list and tokens must have same length" conf_array = np.array(conf_list) total_tokens = len(conf_array) stats = { 'mean_confidence': np.mean(conf_array) } # First/Last N tokens for n in [2048]: if total_tokens >= n: stats[f'tail_{n}_mean_conf'] = np.mean(conf_array[-n:]) else: stats[f'tail_{n}_mean_conf'] = np.mean(conf_array) # First/Last percentage for ratio in [0.1]: n_tokens = max(1, int(total_tokens * ratio)) stats[f'tail_{ratio}_mean_conf'] = np.mean(conf_array[-n_tokens:]) # Sliding window metrics window_sizes = [2048] bottom_percentages = [0.1, 0.5] for window_size in window_sizes: if total_tokens < window_size: stats[f'min_sliding_{window_size}_mean_conf'] = np.mean(conf_array) for percent in bottom_percentages: stats[f'bottom_{percent}_sliding_{window_size}_mean_conf'] = np.mean(conf_array) else: # Optimized sliding window cumsum = np.cumsum(conf_array) window_sums = cumsum[window_size-1:] window_sums[1:] -= cumsum[:-window_size] window_means = window_sums / window_size stats[f'min_sliding_{window_size}_mean_conf'] = np.min(window_means) sorted_means = np.sort(window_means) for percent in bottom_percentages: idx = int(len(sorted_means) * percent) stats[f'bottom_{percent}_sliding_{window_size}_mean_conf'] = sorted_means[:idx].mean() return stats # ============================================ # SECTION 4: Voting Strategies # ============================================ def majority_vote(traces): """Perform majority voting based on extracted answers""" if not traces: return None, None answer_counts = {} answer_to_parsed = {} for trace in traces: extracted_answer = trace.get('extracted_answer') parsed_answer = trace.get('parsed_answer') if extracted_answer is not None: answer_str = str(extracted_answer) answer_counts[answer_str] = answer_counts.get(answer_str, 0) + 1 if answer_str not in answer_to_parsed: 
answer_to_parsed[answer_str] = parsed_answer if not answer_counts: return None, None voted_answer = max(answer_counts.keys(), key=lambda x: answer_counts[x]) voted_parsed = answer_to_parsed[voted_answer] return voted_answer, voted_parsed def weighted_majority_vote(traces, weight_key='mean_confidence'): """Perform weighted majority voting""" if not traces: return None, None answer_weights = {} answer_to_parsed = {} for trace in traces: extracted_answer = trace.get('extracted_answer') parsed_answer = trace.get('parsed_answer') weight = trace.get(weight_key) if extracted_answer is not None and weight is not None: answer_str = str(extracted_answer) answer_weights[answer_str] = answer_weights.get(answer_str, 0.0) + float(weight) if answer_str not in answer_to_parsed: answer_to_parsed[answer_str] = parsed_answer if not answer_weights: return None, None voted_answer = max(answer_weights.keys(), key=lambda x: answer_weights[x]) voted_parsed = answer_to_parsed[voted_answer] return voted_answer, voted_parsed def top_percent_vote(traces, weight_key='mean_confidence', top_percent=0.1, vote_strategy='majority'): """ First filter top percent of traces by weight_key, then perform voting Args: traces: List of trace dictionaries weight_key: Key to use for filtering (e.g., 'mean_confidence') top_percent: Percentage of top traces to keep (e.g., 0.1 for top 10%) vote_strategy: 'majority' or 'weighted' Returns: voted_answer, voted_parsed """ if not traces: return None, None # Filter traces that have the weight_key and valid answers valid_traces = [t for t in traces if weight_key in t and t.get('extracted_answer') is not None] if not valid_traces: return None, None # Sort traces by weight_key in descending order (higher is better) sorted_traces = sorted(valid_traces, key=lambda x: x[weight_key], reverse=True) # Select top percent n_top = max(1, int(len(sorted_traces) * top_percent)) top_traces = sorted_traces[:n_top] # Apply voting strategy on filtered traces if vote_strategy == 'majority': return majority_vote(top_traces) elif vote_strategy == 'weighted': return weighted_majority_vote(top_traces, weight_key) else: raise ValueError(f"Unknown vote_strategy: {vote_strategy}") # ============================================ # SECTION 5: JSONL Processing # ============================================ def process_jsonl_file(file_path, ground_truth=None): """Process a single JSONL file and extract traces with metrics""" traces = [] with open(file_path, 'r') as f: lines = f.readlines() for line_num, line in enumerate(lines): if not line.strip(): continue try: data = json.loads(line) # Extract response and confidence data response = data.get('response', '') mean_confidences = data.get('mean_confidences', []) tokens = data.get('tokens', []) question_meta = data['question_meta']['original_question'] # Extract answer extracted_answer = extract_answer(response) parsed_answer = parse_func(extracted_answer) if extracted_answer else None # Get ground truth if ground_truth is None: # Try to extract from question_meta for field in ['answer', 'solution', 'target']: if field in question_meta: ground_truth = str(question_meta[field]).strip() break # Calculate confidence statistics conf_stats = calculate_confidence_stats(mean_confidences, tokens) # Check correctness is_correct = False if extracted_answer is not None and ground_truth is not None: is_correct = math_equal(extracted_answer, ground_truth) # Create trace entry trace = { 'trace_id': data.get('trace_id', line_num), 'extracted_answer': extracted_answer, 'parsed_answer': 
parsed_answer, 'is_correct': is_correct, 'ground_truth': ground_truth, 'response': response, **conf_stats } traces.append(trace) except Exception as e: print(f"Error processing line {line_num} in {file_path}: {e}") continue return traces def process_multiple_jsonls(file_pattern, ground_truth_map=None): """Process multiple JSONL files matching a pattern""" files = glob.glob(file_pattern) all_data = defaultdict(list) for file_path in tqdm(files, desc="Processing JSONL files"): # Extract question ID from filename if possible filename = os.path.basename(file_path) question_id = None # Try to extract question ID (adjust pattern as needed) if '_processed.jsonl' in filename: try: question_id = int(filename.replace('_processed.jsonl', '')) except: question_id = filename else: question_id = filename # Get ground truth for this question ground_truth = None if ground_truth_map and question_id in ground_truth_map: ground_truth = ground_truth_map[question_id] # Process the file traces = process_jsonl_file(file_path, ground_truth) if traces: all_data[question_id] = traces return dict(all_data) def process_multiple_dirs_jsonls(trace_dirs, file_pattern="*_processed.jsonl", ground_truth_map=None): """ Process JSONL files from multiple directories and merge traces with same filename Args: trace_dirs: List of directory paths to search for JSONL files file_pattern: File pattern to match (e.g., "*_processed.jsonl") ground_truth_map: Optional dictionary mapping question IDs to ground truth answers Returns: Dictionary where keys are question IDs and values are lists of merged traces """ all_data = defaultdict(list) # First, collect all unique filenames across all directories all_filenames = set() dir_file_mapping = defaultdict(list) # Track which dirs have which files for trace_dir in trace_dirs: if not os.path.exists(trace_dir): print(f"Warning: Directory {trace_dir} does not exist, skipping...") continue pattern = os.path.join(trace_dir, file_pattern) files = glob.glob(pattern) print(f"Found {len(files)} files in {trace_dir}") for file_path in files: filename = os.path.basename(file_path) all_filenames.add(filename) dir_file_mapping[filename].append(trace_dir) print(f"Total unique filenames found: {len(all_filenames)}") # Process each unique filename across all directories for filename in tqdm(all_filenames, desc="Processing unique files"): # Extract question ID from filename question_id = None if '_processed.jsonl' in filename: try: question_id = int(filename.replace('_processed.jsonl', '')) except: question_id = filename else: question_id = filename # Get ground truth for this question ground_truth = None if ground_truth_map and question_id in ground_truth_map: ground_truth = ground_truth_map[question_id] # Collect traces from all directories for this filename merged_traces = [] dirs_with_file = dir_file_mapping[filename] for trace_dir in dirs_with_file: file_path = os.path.join(trace_dir, filename) if os.path.exists(file_path): try: traces = process_jsonl_file(file_path, ground_truth) # Add directory info to each trace for identification for i, trace in enumerate(traces): trace['source_dir'] = trace_dir trace['source_file'] = filename # Create unique trace ID combining dir and original trace ID original_trace_id = trace.get('trace_id', i) trace['trace_id'] = f"{os.path.basename(trace_dir)}_{original_trace_id}" merged_traces.extend(traces) except Exception as e: print(f"Error processing {file_path}: {e}") continue if merged_traces: all_data[question_id] = merged_traces print(f"Question {question_id}: Merged 
{len(merged_traces)} traces from {len(dirs_with_file)} directories") return dict(all_data) # ============================================ # SECTION 6: Analysis and Evaluation # ============================================ def analyze_voting_performance(data, voting_sizes=[1, 2, 4, 8, 16, 32], strategy='majority', weight_key='mean_confidence', n_trials=1, seed=42, top_percent=None): """Analyze voting performance across different ensemble sizes""" random.seed(seed) np.random.seed(seed) results = {} for vote_size in voting_sizes: accuracies = [] for trial in range(n_trials): correct = 0 total = 0 for question_id, traces in data.items(): if len(traces) < vote_size: continue # Sample traces sampled = random.sample(traces, vote_size) # Apply voting strategy if strategy == 'majority': voted_answer, _ = majority_vote(sampled) elif strategy == 'weighted': voted_answer, _ = weighted_majority_vote(sampled, weight_key) elif strategy == 'top_percent': voted_answer, _ = top_percent_vote(sampled, weight_key, top_percent, 'majority') elif strategy == 'top_percent_weighted': voted_answer, _ = top_percent_vote(sampled, weight_key, top_percent, 'weighted') else: raise ValueError(f"Unknown strategy: {strategy}") # Check correctness if voted_answer is not None and sampled[0]['ground_truth'] is not None: ground_truth = sampled[0]['ground_truth'] if math_equal(voted_answer, ground_truth): correct += 1 total += 1 if total > 0: accuracies.append(correct / total) if accuracies: results[vote_size] = { 'accuracy_mean': np.mean(accuracies), 'accuracy_std': np.std(accuracies) } return results def analyze_top_percent_strategies(data, voting_sizes=[1, 2, 4, 8], weight_keys=['mean_confidence', 'tail_2048_mean_conf'], top_percents=[0.1, 0.2, 0.3, 0.5], n_trials=1, seed=42): """ Comprehensive analysis of top percent filtering strategies Args: data: Processed trace data voting_sizes: List of ensemble sizes to test weight_keys: List of confidence metrics to use for filtering top_percents: List of top percentages to test (e.g., [0.1, 0.2] for top 10%, 20%) n_trials: Number of random trials per configuration seed: Random seed for reproducibility Returns: Dictionary with results for each configuration """ results = {} # Test each combination of parameters for weight_key in weight_keys: for top_percent in top_percents: for vote_strategy in ['weighted']: strategy_name = f"top_{int(top_percent*100)}%_{vote_strategy}_{weight_key}" print(f"Testing {strategy_name}...") strategy = 'top_percent' if vote_strategy == 'majority' else 'top_percent_weighted' strategy_results = analyze_voting_performance( data, voting_sizes=voting_sizes, strategy=strategy, weight_key=weight_key, n_trials=n_trials, seed=seed, top_percent=top_percent ) results[strategy_name] = strategy_results return results def analyze_directory_distribution(data): """Analyze the distribution of traces across source directories""" print("\n" + "="*60) print("DIRECTORY DISTRIBUTION ANALYSIS") print("="*60) dir_stats = defaultdict(lambda: {'total_traces': 0, 'questions': set()}) for question_id, traces in data.items(): for trace in traces: source_dir = trace.get('source_dir', 'unknown') dir_stats[source_dir]['total_traces'] += 1 dir_stats[source_dir]['questions'].add(question_id) print(f"{'Directory':<50} {'Traces':<10} {'Questions':<10}") print("-" * 72) for dir_name, stats in dir_stats.items(): short_name = os.path.basename(dir_name) if dir_name != 'unknown' else 'unknown' print(f"{short_name:<50} {stats['total_traces']:<10} {len(stats['questions']):<10}") return dir_stats # 
============================================ # MAIN EXECUTION CELL # ============================================ def main_analysis_multi_dir(trace_dirs=None, file_pattern="*_processed.jsonl", ground_truth_file=None, output_dir="./results"): """ Main analysis function for multiple directories Args: trace_dirs: List of directory paths to search for JSONL files file_pattern: File pattern to match (e.g., "*_processed.jsonl") ground_truth_file: Optional pickle file with ground truth answers output_dir: Directory to save results """ if trace_dirs is None: trace_dirs = TRACE_DIRS print("="*60) print("ENHANCED MULTI-DIRECTORY JSONL VOTING ANALYSIS") print("="*60) print(f"Processing {len(trace_dirs)} directories:") for i, dir_path in enumerate(trace_dirs): print(f" {i+1}. {dir_path}") # Load ground truth if provided ground_truth_map = None if ground_truth_file and os.path.exists(ground_truth_file): with open(ground_truth_file, 'rb') as f: ground_truth_map = pickle.load(f) print(f"Loaded ground truth from {ground_truth_file}") # Process JSONL files from multiple directories print(f"\nProcessing files with pattern: {file_pattern}") data = process_multiple_dirs_jsonls(trace_dirs, file_pattern, ground_truth_map) print(f"Processed {len(data)} questions") total_traces = sum(len(traces) for traces in data.values()) print(f"Total traces: {total_traces}") # Analyze directory distribution dir_stats = analyze_directory_distribution(data) # Calculate per-question statistics print("\n" + "="*60) print("PER-QUESTION STATISTICS") print("="*60) question_items = list(data.items()) for q_id, traces in question_items[:5]: # Show first 5 questions correct = sum(1 for t in traces if t['is_correct']) print(f"Question {q_id}: {correct}/{len(traces)} correct ({correct/len(traces):.2%})") if traces: mean_conf = np.mean([t.get('mean_confidence', 0) for t in traces if 'mean_confidence' in t]) print(f" Mean confidence: {mean_conf:.4f}") # Show directory breakdown dir_breakdown = defaultdict(int) for trace in traces: dir_name = os.path.basename(trace.get('source_dir', 'unknown')) dir_breakdown[dir_name] += 1 print(f" Directory breakdown: {dict(dir_breakdown)}") # Test baseline strategies print("\n" + "="*60) print("BASELINE VOTING STRATEGIES") print("="*60) baseline_strategies = [ ('majority', 'Majority Vote', None), ('weighted', 'Weighted Vote (mean_conf)', 'mean_confidence'), ('weighted', 'Weighted Vote (tail_2048)', 'tail_2048_mean_conf'), ] voting_sizes = [1, 2, 4, 8, 16, 32, 64, 128] all_results = {} for strategy, name, weight_key in baseline_strategies: print(f"\n{name}:") results = analyze_voting_performance( data, voting_sizes=voting_sizes, strategy=strategy, weight_key=weight_key, n_trials=10 ) all_results[name] = results print(f"{'Size':<6} {'Accuracy':<12} {'Std Dev':<10}") print("-" * 30) for size in voting_sizes: if size in results: acc = results[size]['accuracy_mean'] std = results[size]['accuracy_std'] print(f"{size:<6} {acc:<12.4f} {std:<10.4f}") # Test top percent filtering strategies print("\n" + "="*60) print("TOP PERCENT FILTERING STRATEGIES") print("="*60) top_percent_results = analyze_top_percent_strategies( data, voting_sizes=voting_sizes, # Use smaller sizes for top percent weight_keys=['mean_confidence', 'tail_2048_mean_conf', 'bottom_0.1_sliding_2048_mean_conf'], top_percents=[0.1, 0.9], n_trials=10 ) all_results.update(top_percent_results) # Display top percent results for strategy_name, strategy_results in top_percent_results.items(): print(f"\n{strategy_name}:") print(f"{'Size':<6} 
{'Accuracy':<12} {'Std Dev':<10}") print("-" * 30) for size in voting_sizes: if size in strategy_results: acc = strategy_results[size]['accuracy_mean'] std = strategy_results[size]['accuracy_std'] print(f"{size:<6} {acc:<12.4f} {std:<10.4f}") # Save results os.makedirs(output_dir, exist_ok=True) # Save processed data data_path = os.path.join(output_dir, "processed_data_multi_dir.pkl") with open(data_path, 'wb') as f: pickle.dump(data, f) print(f"\n✓ Saved processed data to {data_path}") # Save voting results results_path = os.path.join(output_dir, "voting_results_multi_dir.pkl") with open(results_path, 'wb') as f: pickle.dump(all_results, f) print(f"✓ Saved voting results to {results_path}") # Create comprehensive summary DataFrame summary_data = [] for strategy_name, strategy_results in all_results.items(): for size, metrics in strategy_results.items(): summary_data.append({ 'Strategy': strategy_name, 'Ensemble Size': size, 'Accuracy': metrics['accuracy_mean'], 'Std Dev': metrics['accuracy_std'] }) df_summary = pd.DataFrame(summary_data) csv_path = os.path.join(output_dir, "voting_summary_multi_dir.csv") df_summary.to_csv(csv_path, index=False) print(f"✓ Saved summary CSV to {csv_path}") # Find best performing strategies print("\n" + "="*60) print("BEST PERFORMING STRATEGIES") print("="*60) # Group by ensemble size and find best accuracy for each for size in voting_sizes: size_results = df_summary[df_summary['Ensemble Size'] == size] if not size_results.empty: best_row = size_results.loc[size_results['Accuracy'].idxmax()] print(f"Size {size}: {best_row['Strategy']} (Accuracy: {best_row['Accuracy']:.4f})") return data, all_results, df_summary, dir_stats # ============================================ # EXAMPLE USAGE # ============================================ if __name__ == "__main__": # Example usage - modify these paths for your data # Use the predefined TRACE_DIRS or specify your own custom_dirs = None # Set to your list of directories if different from TRACE_DIRS file_pattern = "*_processed.jsonl" # Adjust this pattern ground_truth_file = None # Optional: "./ground_truth.pkl" output_directory = "./voting_results_multi_dir" # Run the analysis data, results, summary_df, dir_stats = main_analysis_multi_dir( trace_dirs=custom_dirs, # Will use TRACE_DIRS if None file_pattern=file_pattern, ground_truth_file=ground_truth_file, output_dir=output_directory ) # Display summary print("\n" + "="*60) print("FINAL SUMMARY") print("="*60) print(summary_df.to_string(index=False)) # Show some example merged data print("\n" + "="*60) print("EXAMPLE MERGED DATA") print("="*60) for q_id, traces in list(data.items())[:2]: # Show first 2 questions print(f"\nQuestion {q_id} ({len(traces)} traces):") for trace in traces[:3]: # Show first 3 traces source_dir = os.path.basename(trace.get('source_dir', 'unknown')) correct = trace.get('is_correct', False) conf = trace.get('mean_confidence', 0) answer = trace.get('extracted_answer', 'None') print(f" {source_dir}: {answer} (correct: {correct}, conf: {conf:.4f})")
Example 3: Online Generation
python import openai import json from tqdm import tqdm import time import os from datetime import datetime import numpy as np from collections import Counter from dynasor.core.evaluator import math_equal # =========================== # Configuration # =========================== MODEL_PATH = "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B" MAX_TOKENS = 64000 RID = 0 QID = 0 PORT = 8000 DATASET_FILE = "aime25.jsonl" # Online algorithm parameters WARMUP_TRACES = 16 TOTAL_BUDGET = 32 CONFIDENCE_PERCENTILE = 90 WINDOW_SIZE = 2048 # =========================== # Answer Extraction Functions # =========================== def extract_answer(text): """Extract boxed answer from text""" if "boxed" in text: ans = text.split("boxed")[-1] if len(ans) == 0: return "" elif ans[0] == "{": stack = 1 a = "" for c in ans[1:]: if c == "{": stack += 1 a += c elif c == "}": stack -= 1 if stack == 0: break a += c else: a += c else: a = ans.split("$")[0].strip() return a.strip() return None def parse_func(s): for f in [parse_latex, parse_expr, latex2sympy]: try: return f(s.replace("\\\\", "\\")) except: try: return f(s) except: pass return s # =========================== # Data Loading # =========================== def load_aime25_jsonl(file_path=DATASET_FILE): """Load data from aime25.jsonl file""" data = [] with open(file_path, 'r', encoding='utf-8') as file: for line in file: data.append(json.loads(line.strip())) return data # =========================== # Confidence Calculation # =========================== def compute_confidence(logprobs): """Compute confidence score from logprobs.""" confs = [] for lp in logprobs: confs.append(round(-sum([l.logprob for l in lp]) / len(lp), 3)) return confs def compute_least_grouped(confs, group_size=WINDOW_SIZE): """Compute sliding window mean confidence with specified group size.""" if len(confs) < group_size: return [sum(confs) / len(confs)] if confs else [0] sliding_means = [] for i in range(len(confs) - group_size + 1): window = confs[i:i + group_size] sliding_means.append(round(sum(window) / len(window), 3)) return sliding_means # =========================== # Voting Functions # =========================== def weighted_majority_vote(answers, weights): """Perform weighted majority voting""" if not answers: return None answer_weights = {} for answer, weight in zip(answers, weights): if answer is not None: answer_str = str(answer) answer_weights[answer_str] = answer_weights.get(answer_str, 0.0) + float(weight) if not answer_weights: return None voted_answer = max(answer_weights.keys(), key=lambda x: answer_weights[x]) return voted_answer # =========================== # Trace Processing and Accuracy Calculation # =========================== def process_trace(choice, trace_id, ground_truth): """Process a single trace and calculate accuracy""" # Extract basic info text = choice.message.content tokens = [t.token for t in choice.logprobs.content] # Calculate confidence confs = compute_confidence([t.top_logprobs for t in choice.logprobs.content]) sliding_window = compute_least_grouped(confs, group_size=WINDOW_SIZE) # Extract and parse answer extracted_answer = extract_answer(text) parsed_answer = parse_func(extracted_answer) if extracted_answer else None # Calculate correctness is_correct = False if extracted_answer is not None and ground_truth is not None: is_correct = math_equal(extracted_answer, ground_truth) trace_data = { "trace_id": trace_id, "stop_reason": choice.stop_reason, "finish_reason": choice.finish_reason, "text": text, "tokens": tokens, "token_count": len(tokens), 
"confs": confs, "group_confs": sliding_window, "min_conf": min(sliding_window) if sliding_window else 0, "extracted_answer": extracted_answer, "parsed_answer": parsed_answer, "is_correct": is_correct, } return trace_data def calculate_statistics(traces, phase_name=""): """Calculate statistics for a list of traces""" if not traces: return {} total_traces = len(traces) correct_traces = sum(1 for t in traces if t['is_correct']) total_tokens = sum(t['token_count'] for t in traces) # Confidence statistics min_confs = [t['min_conf'] for t in traces] stats = { f"{phase_name}_traces": total_traces, f"{phase_name}_correct": correct_traces, f"{phase_name}_accuracy": correct_traces / total_traces if total_traces > 0 else 0, f"{phase_name}_total_tokens": total_tokens, f"{phase_name}_avg_tokens_per_trace": total_tokens / total_traces if total_traces > 0 else 0, f"{phase_name}_min_conf_mean": np.mean(min_confs) if min_confs else 0, f"{phase_name}_min_conf_std": np.std(min_confs) if min_confs else 0, } return stats # =========================== # Problem Processing Function (like problemx in reference) # =========================== def process_problem_voting(test_json, ground_truth): """ Process a problem result JSON and perform voting Similar to problemx function in reference code """ answers = [] bar = test_json['conf_bar'] weights = [] tokens = 0 print(f"Warmup traces: {len(test_json['warmup_traces'])}, Final traces: {len(test_json['final_traces'])}") # Process warmup traces for i in range(len(test_json['warmup_traces'])): answer = extract_answer(test_json['warmup_traces'][i]['text']) minx = min(test_json['warmup_traces'][i]['group_confs']) tokens += len(test_json['warmup_traces'][i]['tokens']) if minx < bar: continue if answer is not None: answers.append(answer) weights.append(1) # Use weight 1 for consistency with reference # Process final traces for i in range(len(test_json['final_traces'])): tokens += len(test_json['final_traces'][i]['tokens']) # Skip traces stopped by gconf if test_json['final_traces'][i]['stop_reason'] is not None and 'gconf' in test_json['final_traces'][i]['stop_reason']: continue answer = extract_answer(test_json['final_traces'][i]['text']) minx = min(test_json['final_traces'][i]['group_confs']) if answer is not None: answers.append(answer) weights.append(1) # Perform voting voted = weighted_majority_vote(answers, weights) is_correct = str(voted) == str(ground_truth) if voted is not None else False print(f'Bar: {bar:.4f}, Voted: {voted}, Ground truth: {ground_truth}, Correct: {is_correct}, Voting answers: {len(answers)}') return is_correct, len(answers), tokens # =========================== # Main Function # =========================== def main(): print("="*60) print("ONLINE ALGORITHM WITH ACCURACY CALCULATION") print("="*60) print(f"Model: {MODEL_PATH}") print(f"Question ID: {QID}") print(f"Run ID: {RID}") print(f"Warmup traces: {WARMUP_TRACES}") print(f"Total budget: {TOTAL_BUDGET}") print(f"Confidence percentile: {CONFIDENCE_PERCENTILE}") # Load the data data = load_aime25_jsonl() print(f"Loaded {len(data)} items from {DATASET_FILE}") # Initialize client client = openai.OpenAI( api_key="None", base_url=f"http://localhost:{PORT}/v1", timeout=None ) # Get question and ground truth prompt = data[QID]['question'] ground_truth = str(data[QID].get('answer', '')).strip() messages = [ {"role": "system", "content": "该助手为DeepSeek-R1,由深度求索公司创造。\n今天是2025年5月28日,星期一。\n"}, {"role": "user", "content": prompt} ] print(f"\nQuestion: {prompt}") print(f"Ground Truth: {ground_truth}") # 
=========================== # WARMUP PHASE # =========================== print(f"\n{'-'*40}") print("WARMUP PHASE") print(f"{'-'*40}") t0 = time.time() responses = client.chat.completions.create( model=MODEL_PATH, messages=messages, max_tokens=MAX_TOKENS, temperature=0.6, top_p=0.95, logprobs=True, top_logprobs=20, n=WARMUP_TRACES, extra_body={"top_k": 0}, ) t1 = time.time() # Process warmup traces warmup_traces = [] min_confs = [] for j in range(WARMUP_TRACES): trace_data = process_trace(responses.choices[j], j, ground_truth) warmup_traces.append(trace_data) min_confs.append(trace_data["min_conf"]) # Calculate confidence bar conf_bar = float(np.percentile(min_confs, CONFIDENCE_PERCENTILE)) # Calculate warmup statistics warmup_stats = calculate_statistics(warmup_traces, "warmup") print(f"Warmup time: {t1 - t0:.2f}s") print(f"Confidence bar (P{CONFIDENCE_PERCENTILE}): {conf_bar:.4f}") print(f"Warmup accuracy: {warmup_stats['warmup_accuracy']:.4f} ({warmup_stats['warmup_correct']}/{warmup_stats['warmup_traces']})") print(f"Warmup total tokens: {warmup_stats['warmup_total_tokens']}") print(f"Warmup avg tokens per trace: {warmup_stats['warmup_avg_tokens_per_trace']:.1f}") # Show some example results print(f"\nFirst 3 warmup traces:") for i, trace in enumerate(warmup_traces[:3]): print(f" Trace {i}: {trace['extracted_answer']} (correct: {trace['is_correct']}, conf: {trace['min_conf']:.4f}, tokens: {trace['token_count']})") # =========================== # FINAL PHASE # =========================== print(f"\n{'-'*40}") print("FINAL PHASE (with early stopping)") print(f"{'-'*40}") real_gen = TOTAL_BUDGET - WARMUP_TRACES t3 = time.time() responses = client.chat.completions.create( model=MODEL_PATH, messages=messages, max_tokens=MAX_TOKENS, temperature=0.6, top_p=0.95, logprobs=True, top_logprobs=20, n=real_gen, extra_body={ "top_k": 0, "vllm_xargs": { 'enable_conf': True, 'window_size': WINDOW_SIZE, 'threshold': conf_bar } } ) t4 = time.time() # Process final traces final_traces = [] for j in range(len(responses.choices)): trace_data = process_trace(responses.choices[j], WARMUP_TRACES + j, ground_truth) final_traces.append(trace_data) # Calculate final statistics final_stats = calculate_statistics(final_traces, "final") print(f"Final time: {t4 - t3:.2f}s") print(f"Final traces generated: {len(final_traces)} (requested: {real_gen})") print(f"Final accuracy: {final_stats['final_accuracy']:.4f} ({final_stats['final_correct']}/{final_stats['final_traces']})") print(f"Final total tokens: {final_stats['final_total_tokens']}") print(f"Final avg tokens per trace: {final_stats['final_avg_tokens_per_trace']:.1f}") # Show some example results print(f"\nFirst 3 final traces:") for i, trace in enumerate(final_traces[:3]): print(f" Trace {i}: {trace['extracted_answer']} (correct: {trace['is_correct']}, conf: {trace['min_conf']:.4f}, tokens: {trace['token_count']})") # =========================== # OVERALL STATISTICS # =========================== print(f"\n{'-'*40}") print("OVERALL STATISTICS") print(f"{'-'*40}") all_traces = warmup_traces + final_traces overall_stats = calculate_statistics(all_traces, "overall") total_time = t4 - t0 warmup_time = t1 - t0 final_time = t4 - t3 print(f"Total time: {total_time:.2f}s (warmup: {warmup_time:.2f}s, final: {final_time:.2f}s)") print(f"Overall traces: {len(all_traces)}") print(f"Overall accuracy: {overall_stats['overall_accuracy']:.4f} ({overall_stats['overall_correct']}/{overall_stats['overall_traces']})") print(f"Overall total tokens: 
{overall_stats['overall_total_tokens']}") print(f"Overall avg tokens per trace: {overall_stats['overall_avg_tokens_per_trace']:.1f}") # Efficiency metrics tokens_per_second = overall_stats['overall_total_tokens'] / total_time traces_per_second = len(all_traces) / total_time print(f"Tokens per second: {tokens_per_second:.1f}") print(f"Traces per second: {traces_per_second:.2f}") # Early stopping analysis if final_traces: above_threshold = sum(1 for t in final_traces if t['min_conf'] >= conf_bar) print(f"Final traces above threshold: {above_threshold}/{len(final_traces)} ({above_threshold/len(final_traces):.2%})") # =========================== # SAVE RESULTS # =========================== results = { "question_id": QID, "run_id": RID, "question": prompt, "ground_truth": ground_truth, "conf_bar": conf_bar, "warmup_traces": warmup_traces, "final_traces": final_traces, "statistics": { **warmup_stats, **final_stats, **overall_stats, "total_time": total_time, "warmup_time": warmup_time, "final_time": final_time, "tokens_per_second": tokens_per_second, "traces_per_second": traces_per_second, }, "config": { "model_path": MODEL_PATH, "warmup_traces": WARMUP_TRACES, "total_budget": TOTAL_BUDGET, "confidence_percentile": CONFIDENCE_PERCENTILE, "window_size": WINDOW_SIZE, }, "timestamp": datetime.now().isoformat(), } # Save results to JSON file timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") os.makedirs("outputs", exist_ok=True) pickle.dump(results, open(f"outputs/q{QID}_r{RID}_{timestamp}.pkl", 'wb')) print(f"Results saved to outputs/q{QID}_r{RID}_{timestamp}.pkl") output_filename = f"outputs/q{QID}_r{RID}_{timestamp}.pkl" print(f"\n✅ Results saved to outputs/q{QID}_r{RID}_{timestamp}.pkl") # =========================== # COMPARISON WITH BASELINE # =========================== print(f"\n{'-'*40}") print("COMPARISON ANALYSIS") print(f"{'-'*40}") # Compare warmup vs final phase if warmup_stats['warmup_traces'] > 0 and final_stats['final_traces'] > 0: acc_improvement = final_stats['final_accuracy'] - warmup_stats['warmup_accuracy'] print(f"Accuracy change (final - warmup): {acc_improvement:+.4f}") token_efficiency = final_stats['final_avg_tokens_per_trace'] / warmup_stats['warmup_avg_tokens_per_trace'] print(f"Token efficiency (final/warmup): {token_efficiency:.2f}x") # Early stopping effectiveness total_requested = WARMUP_TRACES + (TOTAL_BUDGET - WARMUP_TRACES) total_generated = len(all_traces) saving_ratio = (total_requested - total_generated) / total_requested print(f"Traces requested: {total_requested}") print(f"Traces generated: {total_generated}") print(f"Early stopping saving: {saving_ratio:.2%}") # =========================== # VOTING FOR FINAL ANSWER # =========================== print(f"\n{'-'*40}") print("VOTING FOR FINAL ANSWER") print(f"{'-'*40}") # Collect answers above threshold for voting voting_answers = [] voting_weights = [] total_voting_tokens = 0 # Add warmup traces above threshold (use confidence as weight) for trace in warmup_traces: minx = trace['min_conf'] total_voting_tokens += trace['token_count'] if minx >= conf_bar: answer = trace['extracted_answer'] if answer is not None: voting_answers.append(answer) voting_weights.append(minx) # Add final traces above threshold (use weight 1) for trace in final_traces: total_voting_tokens += trace['token_count'] # Skip traces that were stopped by gconf (early stopping) if trace['stop_reason'] is not None and 'gconf' in trace['stop_reason']: continue minx = trace['min_conf'] # Note: final traces might not need threshold check since they 
were already filtered # But keeping it consistent with the reference code if minx >= conf_bar: answer = trace['extracted_answer'] if answer is not None: voting_answers.append(answer) voting_weights.append(1) print(f"Traces used for voting: {len(voting_answers)}") print(f"Total tokens in voting traces: {total_voting_tokens}") # Perform weighted majority vote voted_answer = weighted_majority_vote(voting_answers, voting_weights) # Check if voted answer is correct is_voted_correct = False if voted_answer is not None and ground_truth: is_voted_correct = str(voted_answer) == str(ground_truth) # Also try math_equal for more robust comparison try: is_voted_correct_math = math_equal(voted_answer, ground_truth) if is_voted_correct != is_voted_correct_math: print(f"Warning: String comparison ({is_voted_correct}) != math comparison ({is_voted_correct_math})") is_voted_correct = is_voted_correct_math # Use math_equal as more reliable except: pass # Fallback to string comparison print(f"Voted answer: {voted_answer}") print(f"Ground truth: {ground_truth}") print(f"Voted answer correct: {is_voted_correct}") # Show voting breakdown if voting_answers: answer_counts = Counter(voting_answers) print(f"\nVoting breakdown:") for answer, count in answer_counts.most_common(): print(f" {answer}: {count} votes") # =========================== # UPDATE RESULTS WITH VOTING # =========================== voting_results = { "voting_answers": voting_answers, "voting_weights": voting_weights, "voted_answer": voted_answer, "is_voted_correct": is_voted_correct, "voting_traces_count": len(voting_answers), "voting_total_tokens": total_voting_tokens, } results["voting"] = voting_results results["statistics"]["voting_traces_count"] = len(voting_answers) results["statistics"]["voting_total_tokens"] = total_voting_tokens results["statistics"]["is_voted_correct"] = is_voted_correct # Update the saved file with voting results # with open(output_filename, 'w', encoding='utf-8') as f: # pickle.dump(results, f) # print(f"\n✅ Results updated with voting information: {output_filename}") # =========================== # FINAL SUMMARY # =========================== print(f"\n{'='*60}") print("FINAL SUMMARY") print(f"{'='*60}") print(f"Question ID: {QID}") print(f"Confidence bar: {conf_bar:.4f}") print(f"Total traces generated: {len(all_traces)}") print(f"Traces used for voting: {len(voting_answers)}") print(f"Total tokens: {overall_stats['overall_total_tokens']}") print(f"Voting answer: {voted_answer}") print(f"Ground truth: {ground_truth}") print(f"Final result: {'✅ CORRECT' if is_voted_correct else '❌ INCORRECT'}") return results if __name__ == "__main__": results = main()
⚠️ Important Notes
- Make sure the vLLM server is running before executing the examples
- Adjust the configuration parameters based on your specific requirements
- The examples use the Dynasor library for evaluation - ensure it's properly installed
- For production use, consider adding proper error handling and logging (a minimal retry sketch follows below)
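As a minimal sketch of the last point (the wrapper name, retry count, and backoff policy are our own choices, not part of DeepConf or vLLM):

```python
import logging
import time

def create_with_retries(client, max_retries=3, backoff_s=2.0, **kwargs):
    """Call client.chat.completions.create with basic retries and logging."""
    for attempt in range(1, max_retries + 1):
        try:
            return client.chat.completions.create(**kwargs)
        except Exception as exc:  # narrow to the client's error types in real code
            logging.warning("Request failed (attempt %d/%d): %s", attempt, max_retries, exc)
            if attempt == max_retries:
                raise
            time.sleep(backoff_s * attempt)
```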
Summary
This guide has provided you with:
- ✅ Complete instructions for modifying vLLM to implement DeepConf
- ✅ Alternative option to download the pre-implemented PR
- ✅ Three comprehensive examples for different use cases:
- Offline Generation with parallel processing
- Offline Voting with multiple strategies
- Online Generation with confidence-based early stopping
For more information and updates, visit the official PR on GitHub.