🚀 DeepConf on vLLM

Minimal Patch & Usage Guide for Confidence-Based Early Stopping

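This document consolidates everything you need to use DeepConf on vLLM. There are two ways to get a patched vLLM:

Option 1: Download and build the vLLM version from our PR.

Option 2: Apply the edits below to your own vLLM checkout.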

This README describes a minimal set of edits that integrate DeepConf (confidence-based early stopping) into vLLM, explains how to enable it per request via the OpenAI-compatible API, and points to the example notebook.

Tested Environment

  • vLLM: commit 31f09c615f4f067dba765ce5fe7d00d880212a6d
  • Python: 3.12.0
  • CUDA: 12.8

High-Level Changes

We modify only two places in vLLM:

  1. Extend LogprobsProcessor to maintain a sliding-window average of per-token confidence and expose check_conf_stop().
  2. In output_processor.py, insert a single early-stop check before constructing RequestOutput.

Enabling via OpenAI-Compatible API

The feature is toggled per request via the OpenAI-compatible chat.completions endpoint. The arguments are passed through extra_body["vllm_xargs"] and forwarded by vLLM to SamplingParams.extra_args.

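# Code: Enable confidence-based early stopping via OpenAI-compatible API
# (client is an openai.OpenAI instance pointed at the vLLM server; args, messages,
#  real_gen and conf_threshold are defined by the calling script)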
responses = client.chat.completions.create(
    model=args.model_path,
    messages=messages,
    max_tokens=args.max_tokens,
    temperature=0.6,
    top_p=0.95,
    logprobs=True,
    top_logprobs=20,      # request candidate logprobs (>=2)
    n=real_gen,
    extra_body={
        "top_k": 0,
        "vllm_xargs": {
            "enable_conf": True,
            "window_size": 2048,
            "threshold": conf_threshold
        }
    }
)
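
Notes:
  • The early-stop logic is active only when enable_conf is True, logprobs=True, and top_logprobs>=2.
  • window_size is the number of most recent tokens over which confidence is averaged; threshold is the cutoff: once the window is full, generation stops as soon as the window's average confidence falls below it (defaults: 2048 and 17, see Step 3).
  • top_k=0 (optional) disables top-k truncation.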
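Because the patch sets stop_reason to "<gconf<threshold>" (Step 6), early-stopped traces can be told apart from traces that ended normally. A minimal sketch, assuming vLLM's OpenAI-compatible server copies that string unchanged into each choice's stop_reason field; early_stopped_indices is a hypothetical helper, not part of vLLM or the OpenAI client:

```python
from typing import Any, Iterable, List

# Prefix of the stop_reason written by the Step 6 patch: f"<gconf<{threshold}>".
GCONF_PREFIX = "<gconf<"


def early_stopped_indices(choices: Iterable[Any]) -> List[int]:
    """Return the indices of choices that confidence-based early stopping cut short.

    Hypothetical helper. Accepts plain dicts (e.g. responses.model_dump()["choices"])
    or objects exposing `index` and `stop_reason` attributes (e.g. responses.choices).
    Assumes the server forwards the patched stop_reason unchanged.
    """
    stopped: List[int] = []
    for choice in choices:
        if isinstance(choice, dict):
            index, reason = choice.get("index"), choice.get("stop_reason")
        else:
            index = getattr(choice, "index", None)
            reason = getattr(choice, "stop_reason", None)
        # stop_reason can also be None (EOS / max_tokens) or an int (a stop token id).
        if isinstance(reason, str) and reason.startswith(GCONF_PREFIX):
            stopped.append(index)
    return stopped


# With the request above (requires a running patched server):
# stopped = early_stopped_indices(responses.choices)
# print(f"{len(stopped)} of {len(responses.choices)} traces stopped early")
```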

Exact Edits (Copy-Paste Guidance)

No patch tools are required; copy the snippets below into the indicated files. We recommend pinning to the commit above to avoid API drift.

File: vllm/v1/engine/logprobs.py

Step 1: Add the imports (near the top of the file):

from collections import deque
from typing import Optional, List

Step 2: Add these fields to the LogprobsProcessor dataclass:

# --- fields for confidence-based early stopping ---
conf_grouped: float
conf_list: Optional[List[float]]
conf_group_list: Optional[deque]
conf_group_size: int
conf_threshold: Optional[float]

Step 3: Initialize from the request (inside from_new_request(...), right before return cls(...)):

if hasattr(request.sampling_params, "extra_args") \
   and request.sampling_params.extra_args is not None \
   and request.sampling_params.extra_args.get("enable_conf", False):
    conf_group_size = request.sampling_params.extra_args.get("window_size", 2048)
    conf_threshold  = request.sampling_params.extra_args.get("threshold", 17)
    conf_grouped    = 0.0
    conf_group_list = deque(maxlen=conf_group_size)
    conf_list       = []
else:
    conf_group_size = -1
    conf_threshold  = None
    conf_grouped    = 0.0
    conf_group_list = None
    conf_list       = None

Then include the fields below in the return cls(...) call:

conf_group_size=conf_group_size,
conf_grouped=conf_grouped,
conf_list=conf_list,
conf_threshold=conf_threshold,
conf_group_list=conf_group_list,

Step 4: Stop-check helper (add this method inside the class):

def check_conf_stop(self) -> bool:
    """Return True if the confidence window triggers early stopping."""
    if self.conf_group_list is None or len(self.conf_group_list) == 0:
        return False
    # Require a full window; trigger when the moving average is below threshold.
    return (len(self.conf_group_list) >= self.conf_group_size
            and self.conf_grouped / len(self.conf_group_list) < self.conf_threshold)

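Step 5: Update the confidence during sampling (in _update_sample_logprobs(...), inside the per-position loop, right after the logprob dict for that position is appended; logprobs there is the list for the current position, with the sampled token first):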

if self.conf_list is not None:
    # logprobs[0] is the sampled token; use the remaining candidates
    if len(logprobs) > 1:
        new_conf = -sum(logprobs[1:]) / len(logprobs[1:])
    else:
        new_conf = 0.0
    self.conf_list.append(new_conf)

    if len(self.conf_group_list) < self.conf_group_size:
        self.conf_group_list.append(new_conf)
        self.conf_grouped += new_conf
    else:
        self.conf_grouped -= self.conf_group_list.popleft()
        self.conf_group_list.append(new_conf)
        self.conf_grouped += new_conf
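Written out, Steps 4 and 5 compute the following. The notation is introduced here for illustration: for generated position t (counted from 1), \ell_{t,1}, \dots, \ell_{t,k} are the k candidate logprobs in logprobs[1:], w is window_size and \tau is threshold.

```latex
c_t =
\begin{cases}
  -\dfrac{1}{k}\displaystyle\sum_{j=1}^{k} \ell_{t,j}, & k \ge 1,\\[4pt]
  0, & k = 0,
\end{cases}
\qquad
\bar{c}_t = \frac{1}{w}\sum_{i=t-w+1}^{t} c_i \quad (t \ge w),
\qquad
\text{stop at } t \iff t \ge w \ \text{and}\ \bar{c}_t < \tau .
```

Here c_t is the value appended to conf_list, the sum inside \bar{c}_t is conf_grouped, and check_conf_stop() evaluates the last condition. A large c_t means most of the top candidates have low probability, i.e. the distribution at that step is sharply peaked.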

File: vllm/v1/engine/output_processor.py

Step 6: Invoke the stop-check in the decode loop.

Immediately after:

req_state.logprobs_processor.update_from_output(engine_core_output)

insert:

# Confidence-based early stopping (ours)
if req_state.logprobs_processor.check_conf_stop():
    finish_reason = FinishReason.STOP
    stop_reason = f"<gconf<{req_state.logprobs_processor.conf_threshold}>"

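(Leave the subsequent logic that builds RequestOutput unchanged. Because finish_reason is now set while the request is still running in EngineCore, the existing cleanup path, the same one used when the detokenizer detects a stop string, aborts the request there.)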
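To sanity-check Steps 3 to 6 without building vLLM, the same logic can be replayed in plain Python. The sketch below is a hypothetical standalone mirror (ConfWindow and run_until_stop are illustrative names, not vLLM APIs). It assumes, as Step 5 does, that each position's logprob list holds the sampled token at index 0 followed by the candidate logprobs, and it reproduces the stop_reason string from Step 6.

```python
from collections import deque
from typing import Deque, Iterable, Optional, Sequence, Tuple


class ConfWindow:
    """Hypothetical standalone mirror of the per-request state from Steps 3-5."""

    def __init__(self, window_size: int = 2048, threshold: float = 17.0) -> None:
        if window_size < 1:
            raise ValueError("window_size must be >= 1")
        self.window_size = window_size
        self.threshold = threshold
        self.window: Deque[float] = deque(maxlen=window_size)  # plays conf_group_list
        self.total = 0.0  # running sum of the values in self.window (conf_grouped)

    def update(self, position_logprobs: Sequence[float]) -> float:
        """Add one generated position; index 0 is the sampled token, the rest are candidates."""
        candidates = position_logprobs[1:]
        conf = -sum(candidates) / len(candidates) if candidates else 0.0
        if len(self.window) == self.window_size:
            self.total -= self.window[0]  # oldest value; deque(maxlen=...) drops it on append
        self.window.append(conf)
        self.total += conf
        return conf

    def should_stop(self) -> bool:
        """Same rule as check_conf_stop(): full window and mean confidence below threshold."""
        return (len(self.window) >= self.window_size
                and self.total / len(self.window) < self.threshold)


def run_until_stop(
    steps: Iterable[Sequence[float]], window_size: int, threshold: float
) -> Tuple[Optional[int], Optional[str]]:
    """Feed positions one at a time, as the decode loop does; return (stop index, stop_reason)."""
    tracker = ConfWindow(window_size, threshold)
    for t, position_logprobs in enumerate(steps):
        tracker.update(position_logprobs)
        if tracker.should_stop():
            return t, f"<gconf<{threshold}>"  # same format as Step 6
    return None, None  # never triggered: generation would end by EOS or max_tokens


# Usage: toy logprob lists (sampled token first, then two candidates).
confident = [-0.01, -3.0, -5.0]  # candidate mean is -4.0, so confidence 4.0
uncertain = [-0.7, -0.7, -0.7]   # confidence 0.7
steps = [confident] * 4 + [uncertain] * 4
print(run_until_stop(steps, window_size=3, threshold=2.0))  # (5, '<gconf<2.0>')
```

Subtracting the oldest value just before appending keeps the running sum equal to the window's contents, which is the role of conf_grouped in Step 5; deque(maxlen=...) then discards that same value on append, so each token costs O(1) work regardless of window_size.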

Additional Notes

The PR https://github.com/vllm-project/vllm/pull/23201 contains a vLLM version with the above changes already applied; download it to skip the manual edits.

Running Example

Install vLLM

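Either clone vLLM and apply the edits above, or check out the version from the PR https://github.com/vllm-project/vllm/pull/23201. Then build it from the repository root as follows (the patch only touches Python files, so VLLM_USE_PRECOMPILED=1 can reuse precompiled binaries instead of compiling the CUDA kernels):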

VLLM_USE_PRECOMPILED=1 uv pip install --editable .

Install dependencies for the example code

git clone https://github.com/hao-ai-lab/Dynasor.git
cd Dynasor && pip install . && cd -

Example 1: Offline Generation

import openai
                    import json
                    from tqdm import tqdm
                    import time
                    import os
                    import requests
                    from datetime import datetime
                    from transformers import AutoTokenizer
                    import concurrent.futures
                    import threading
                    from functools import partial
                    
                    # ===========================
                    # Model Configurations
                    # ===========================
                    
                    MODEL_CONFIGS = {
                        "Qwen/Qwen3-8B": {
                            "temperature": 0.6,
                            "top_p": 0.95,
                            "top_k": 20,
                            "max_tokens": 32000,
                            "template": "qwen3"
                        },
                        "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B": {
                            "temperature": 0.6,
                            "top_p": 0.95,
                            "top_k": 0,
                            "max_tokens": 64000,
                            "template": "dpsk_qwen_0528"
                        },
                        "openai/gpt-oss-20b": {
                            "temperature": 1.0,
                            "top_p": 1.0,
                            "top_k": 40,
                            "max_tokens": 14000,
                            "template": "gpt"
                        },
                        # Add more model configurations as needed
                    }
                    
                    # ===========================
                    # Main Configuration
                    # ===========================
                    
                    # Select your model
                    MODEL_NAME = "Qwen/Qwen3-8B"  # Change this to your desired model
                    SAMPLES_PER_QUESTION = 4  # Number of traces to generate per question
                    DATASET_FILE = "aime25.jsonl"  # Input dataset file
                    REASONING_EFFORT = "high"  # For GPT models: low, medium, high
                    
                    # Parallel processing configuration
                    MAX_WORKERS = 8  # Maximum number of concurrent workers (adjust based on your server capacity)
                    MAX_WORKERS_PER_QUESTION = 4  # Maximum workers for traces within a single question
                    
                    # Get model-specific config
                    model_config = MODEL_CONFIGS.get(MODEL_NAME)
                    
                    # General Configuration
                    CONFIG = {
                        "model_path": MODEL_NAME,
                        "server_port": 8000,
                        "temperature": model_config["temperature"],
                        "top_p": model_config["top_p"],
                        "top_k": model_config["top_k"],
                        "max_tokens": model_config["max_tokens"],
                        "template": model_config["template"],
                        "reasoning_effort": REASONING_EFFORT,
                    
                        # Dataset and sampling configuration
                        "dataset": DATASET_FILE,  # Input dataset file
                        "max_samples_per_question": SAMPLES_PER_QUESTION,  # Number of traces per question
                        "output_dir": f"output_{MODEL_NAME.replace('/', '_')}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
                    
                        # Parallel processing configuration
                        "max_workers": MAX_WORKERS,
                        "max_workers_per_question": MAX_WORKERS_PER_QUESTION,
                    }
                    
                    # Thread-safe file writing lock
                    file_lock = threading.Lock()
                    
                    # ===========================
                    # Initialize OpenAI Client
                    # ===========================
                    
                    # Note: Make sure vLLM server is already running on the specified port
                    # Example command to start vLLM server:
                    # vllm serve MODEL_NAME --port 8000 -tp 1 --gpu-memory-utilization 0.7 --enable-prefix-caching
                    
                    print(f"Connecting to vLLM server...")
                    print(f"Model: {CONFIG['model_path']}")
                    print(f"Server URL: http://localhost:{CONFIG['server_port']}/v1")
                    print(f"Max concurrent workers: {CONFIG['max_workers']}")
                    print(f"Max workers per question: {CONFIG['max_workers_per_question']}")
                    
                    # Initialize OpenAI client
                    client = openai.OpenAI(
                        api_key="None",
                        base_url=f"http://localhost:{CONFIG['server_port']}/v1",
                        timeout=None
                    )
                    
                    # Initialize tokenizer for GPT models
                    if CONFIG['template'] == 'gpt':
                        tokenizer = AutoTokenizer.from_pretrained(CONFIG['model_path'])
                    else:
                        tokenizer = None
                    
                    # Test connection
                    try:
                        response = requests.get(
                            f"http://localhost:{CONFIG['server_port']}/v1/models",
                            headers={"Authorization": "Bearer None"},
                        )
                        if response.status_code == 200:
                            print("✅ Successfully connected to vLLM server")
                        else:
                            print(f"⚠️ Server returned status code: {response.status_code}")
                    except requests.exceptions.RequestException as e:
                        print(f"❌ Failed to connect to vLLM server: {e}")
                        print("Please ensure vLLM server is running on the specified port")
                    
                    # ===========================
                    # Core Processing Functions
                    # ===========================
                    
                    def get_gpt_token_probabilities(messages, max_tokens=50):
                        """
                        Function to get token probabilities for GPT models using completions API
                        """
                        response = client.completions.create(
                            model=CONFIG['model_path'],
                            prompt=messages,
                            max_tokens=max_tokens,
                            temperature=CONFIG['temperature'],
                            top_p=CONFIG['top_p'],
                            logprobs=20,
                            extra_body={
                                "top_k": CONFIG['top_k']
                            },
                        )
                    
                        # Extract generated text
                        generated_text = response.choices[0].text
                    
                        # Extract token probabilities
                        token_probs = []
                        mean_confs = []
                        tokens = []
                        log_probs = []
                    
                        if response.choices[0].logprobs and response.choices[0].logprobs.tokens:
                            for i, token_data in enumerate(response.choices[0].logprobs.tokens):
                                step_probs = {
                                    "s": i,  # step -> s
                                    "t": response.choices[0].logprobs.tokens[i],  # generated_token -> t
                                    "lp": round(response.choices[0].logprobs.token_logprobs[i], 2),  # logprob of generated token
                                    "a": []  # top_20_tokens -> a (alternatives)
                                }
                    
                                # Add only top 5 alternatives to save space
                                if response.choices[0].logprobs.top_logprobs:
                                    for tok, value in response.choices[0].logprobs.top_logprobs[i].items():  # Only top 5
                                        step_probs["a"].append([
                                            tok,
                                            round(value, 2)
                                        ])  # Use array instead of dict
                    
                                token_probs.append(step_probs)
                                if step_probs['a']:
                                    mean_confs.append(round(-sum(p[1] for p in step_probs['a']) / len(step_probs['a']), 2))
                                else:
                                    mean_confs.append(0)
                                tokens.append(response.choices[0].logprobs.tokens[i])
                                log_probs.append(round(response.choices[0].logprobs.token_logprobs[i], 2))
                    
                        return {
                            "text": generated_text,
                            "probs": token_probs,  # token_probabilities -> probs,
                            "mean_confidences": mean_confs,  # mean_confidences -> mean_confs
                            "tokens": tokens,
                            "log_probs": log_probs  # log_probs -> log_probs
                        }
                    
                    def get_token_probabilities(prompt, messages):
                        """Get token probabilities from the vLLM server using chat completions API."""
                        response = client.chat.completions.create(
                            model=CONFIG['model_path'],
                            messages=messages,
                            max_tokens=CONFIG['max_tokens'],
                            temperature=CONFIG['temperature'],
                            top_p=CONFIG['top_p'],
                            logprobs=True,
                            top_logprobs=20,
                            extra_body={"top_k": CONFIG['top_k']},
                        )
                    
                        generated_text = response.choices[0].message.content
                        token_probs = []
                        mean_confs = []
                        tokens = []
                        log_probs = []
                    
                        if response.choices[0].logprobs and response.choices[0].logprobs.content:
                            for i, token_data in enumerate(response.choices[0].logprobs.content):
                                step_probs = {
                                    "s": i,
                                    "t": token_data.token,
                                    "lp": round(token_data.logprob, 2),
                                    "a": []
                                }
                    
                                if token_data.top_logprobs:
                                    for logprob_data in token_data.top_logprobs[:5]:  # Top 5 alternatives
                                        step_probs["a"].append([
                                            logprob_data.token,
                                            round(logprob_data.logprob, 2)
                                        ])
                    
                                token_probs.append(step_probs)
                                if step_probs['a']:
                                    mean_confs.append(round(-sum(p[1] for p in step_probs['a']) / len(step_probs['a']), 2))
                                else:
                                    mean_confs.append(0)
                                tokens.append(token_data.token)
                                log_probs.append(round(token_data.logprob, 2))
                    
                        return {
                            "text": generated_text,
                            "probs": token_probs,
                            "mean_confidences": mean_confs,
                            "tokens": tokens,
                            "log_probs": log_probs
                        }
                    
                    def prepare_messages(prompt, template):
                        """Prepare messages based on template."""
                        if template == "dpsk_qwen_0528":
                            return [
                                {"role": "system", "content": "该助手为DeepSeek-R1,由深度求索公司创造。\n今天是2025年5月28日,星期一。\n"},
                                {"role": "user", "content": prompt}
                            ]
                        elif template == 'qwen3':
                            return [
                                {"role": "user", "content": prompt + "\nPlease reason step by step, and put your final answer within \\boxed{}."}
                            ]
                        elif template == 'gpt':
                            # For GPT models, we'll prepare a simple string message first
                            return prompt + "\nPlease reason step by step, and put your final answer within \\boxed{{}}."
                        else:
                            return [{"role": "user", "content": prompt}]
                    
                    def generate_single_trace(question_meta, trace_idx, output_dir):
                        """Generate a single trace for a question. This function will be run in parallel."""
                        try:
                            prompt = question_meta["prompt"]
                            q_idx = question_meta["question_id"]
                    
                            messages = prepare_messages(prompt, CONFIG['template'])
                    
                            # Handle GPT models differently
                            if CONFIG['template'] == 'gpt':
                                # Apply chat template with reasoning effort for GPT models
                                if tokenizer:
                                    formatted_messages = tokenizer.apply_chat_template(
                                        conversation=[
                                            {"role": "user", "content": messages}
                                        ],
                                        add_generation_prompt=True,
                                        reasoning_effort=CONFIG['reasoning_effort'],
                                        tokenize=False,
                                    )
                                    result = get_gpt_token_probabilities(messages=formatted_messages, max_tokens=CONFIG['max_tokens'])
                                else:
                                    # Fallback if tokenizer is not available
                                    result = get_gpt_token_probabilities(messages=messages, max_tokens=CONFIG['max_tokens'])
                            else:
                                # Use chat completions for other models
                                result = get_token_probabilities(prompt, messages)
                    
                            # Prepare trace data
                            trace_data_processed = {
                                "question_meta": question_meta,
                                "trace_id": trace_idx,
                                "response": result["text"],
                                "tokens": result["tokens"],
                                "mean_confidences": result["mean_confidences"],
                                "log_probs": result["log_probs"],
                                "messages": messages,
                            }
                    
                            # Thread-safe file writing
                            processed_file = os.path.join(output_dir, f"{q_idx}_processed.jsonl")
                            with file_lock:
                                with open(processed_file, "a", encoding="utf-8") as f:
                                    f.write(json.dumps(trace_data_processed, ensure_ascii=False) + "\n")
                    
                            return True, None
                    
                        except Exception as e:
                            return False, f"Error in question {question_meta['question_id']}, trace {trace_idx}: {e}"
                    
                    def process_question_parallel(question, q_idx, output_dir):
                        """Process a single question and generate multiple traces in parallel."""
                        prompt = question.get("problem", question.get("question", question.get("prompt", "")))
                    
                        if not prompt:
                            print(f"Warning: No prompt found in question {q_idx}")
                            return 0
                    
                        question_meta = {
                            "question_id": q_idx,
                            "original_question": question,
                            "prompt": prompt,
                        }
                    
                        # Generate traces in parallel
                        completed_traces = 0
                        with concurrent.futures.ThreadPoolExecutor(max_workers=CONFIG['max_workers_per_question']) as executor:
                            # Submit all trace generation tasks
                            future_to_trace = {
                                executor.submit(generate_single_trace, question_meta, trace_idx, output_dir): trace_idx
                                for trace_idx in range(CONFIG['max_samples_per_question'])
                            }
                    
                            # Collect results as they complete
                            for future in concurrent.futures.as_completed(future_to_trace):
                                trace_idx = future_to_trace[future]
                                try:
                                    success, error_msg = future.result()
                                    if success:
                                        completed_traces += 1
                                    else:
                                        print(f"Error: {error_msg}")
                                except Exception as e:
                                    print(f"Exception in trace {trace_idx}: {e}")
                    
                        return completed_traces
                    
                    def process_single_question_wrapper(args):
                        """Wrapper function for processing a single question (needed for parallel execution)."""
                        question, q_idx, output_dir = args
                        return q_idx, process_question_parallel(question, q_idx, output_dir)
                    
                    def process_dataset_parallel(dataset_file, output_dir):
                        """Process entire dataset with parallel processing."""
                        os.makedirs(output_dir, exist_ok=True)
                        print(f"Created output directory: {output_dir}")
                    
                        # Load dataset
                        questions = []
                        try:
                            with open(dataset_file, "r", encoding="utf-8") as f:
                                for line in f:
                                    questions.append(json.loads(line.strip()))
                            print(f"Loaded {len(questions)} questions from {dataset_file}")
                        except FileNotFoundError:
                            print(f"Error: {dataset_file} not found!")
                            return None
                    
                        # Process questions in parallel
                        all_results = []
                        total_traces = 0
                    
                        # Prepare arguments for parallel processing
                        question_args = [(question, q_idx, output_dir) for q_idx, question in enumerate(questions)]
                    
                        print(f"Processing {len(questions)} questions with up to {CONFIG['max_workers']} parallel workers...")
                    
                        with concurrent.futures.ThreadPoolExecutor(max_workers=CONFIG['max_workers']) as executor:
                            # Submit all question processing tasks
                            future_to_question = {
                                executor.submit(process_single_question_wrapper, args): args[1]
                                for args in question_args
                            }
                    
                            # Use tqdm to track progress
                            with tqdm(total=len(questions), desc="Processing questions") as pbar:
                                for future in concurrent.futures.as_completed(future_to_question):
                                    q_idx = future_to_question[future]
                                    try:
                                        result_q_idx, traces_completed = future.result()
                                        total_traces += traces_completed
                                        all_results.append({
                                            "question_id": result_q_idx,
                                            "total_traces": traces_completed,
                                            "file_path": os.path.join(output_dir, f"{result_q_idx}.jsonl")
                                        })
                                        pbar.update(1)
                                        pbar.set_postfix({
                                            'completed_traces': total_traces,
                                            'avg_traces': f"{total_traces / len(all_results):.1f}" if all_results else "0"
                                        })
                                    except Exception as e:
                                        print(f"Exception processing question {q_idx}: {e}")
                                        pbar.update(1)
                    
                        # Save summary
                        summary_file = os.path.join(output_dir, "summary.json")
                        summary = {
                            "model": CONFIG['model_path'],
                            "model_config": model_config,
                            "dataset_file": dataset_file,
                            "total_questions": len(questions),
                            "completed_questions": len(all_results),
                            "total_traces": total_traces,
                            "average_traces_per_question": total_traces / len(all_results) if all_results else 0,
                            "output_directory": output_dir,
                            "timestamp": datetime.now().isoformat(),
                            "reasoning_effort": CONFIG.get('reasoning_effort', 'N/A'),
                            "parallel_config": {
                                "max_workers": CONFIG['max_workers'],
                                "max_workers_per_question": CONFIG['max_workers_per_question']
                            }
                        }
                    
                        with open(summary_file, "w", encoding="utf-8") as f:
                            json.dump(summary, f, indent=2, ensure_ascii=False)
                    
                        print(f"\n✅ Completed! Generated {total_traces} total traces")
                        print(f"📁 Results saved to: {output_dir}")
                        print(f"📊 Summary: {summary_file}")
                        print(f"📈 Average traces per question: {total_traces / len(all_results):.1f}")
                    
                        return output_dir
                    
                    def check_results(output_dir):
                        """Check and display results from output directory."""
                        if not os.path.exists(output_dir):
                            print(f"Directory {output_dir} not found!")
                            return
                    
                        # Load summary
                        summary_file = os.path.join(output_dir, "summary.json")
                        if os.path.exists(summary_file):
                            with open(summary_file, "r", encoding="utf-8") as f:
                                summary = json.load(f)
                            print(f"\nSummary:")
                            print(f"  Model: {summary['model']}")
                            print(f"  Total questions: {summary['total_questions']}")
                            print(f"  Total traces: {summary['total_traces']}")
                            print(f"  Average traces per question: {summary['average_traces_per_question']:.1f}")
                            print(f"  Reasoning effort: {summary.get('reasoning_effort', 'N/A')}")
                            if 'parallel_config' in summary:
                                print(f"  Max workers: {summary['parallel_config']['max_workers']}")
                                print(f"  Max workers per question: {summary['parallel_config']['max_workers_per_question']}")
                    
                        # Check individual files
                        question_files = [f for f in os.listdir(output_dir) if f.endswith('.jsonl')]
                        question_files.sort(key=lambda x: int(x.split('.')[0].split('_')[0]) if x.split('.')[0].split('_')[0].isdigit() else float('inf'))
                    
                        print(f"\nFound {len(question_files)} question files")
                    
                        # Show sample results
                        for filename in question_files[:3]:
                            if '_processed' in filename:
                                continue
                    
                            filepath = os.path.join(output_dir, filename)
                            if os.path.exists(filepath):
                                with open(filepath, "r", encoding="utf-8") as f:
                                    lines = f.readlines()
                    
                                if lines:
                                    first_trace = json.loads(lines[0].strip())
                                    print(f"\n{filename}:")
                                    print(f"  Traces: {len(lines)}")
                                    print(f"  First response preview: {first_trace['response'][:150]}...")
                    
                    # ===========================
                    # Performance Monitoring
                    # ===========================
                    
                    def monitor_performance():
                        """Monitor and suggest optimal worker configuration."""
                        import psutil
                    
                        cpu_count = psutil.cpu_count() or 1  # cpu_count() returns None if undetermined
                        memory_gb = psutil.virtual_memory().total / (1024**3)
                    
                        print(f"\n🔧 System Information:")
                        print(f"   CPU cores: {cpu_count}")
                        print(f"   Total memory: {memory_gb:.1f} GB")
                    
                        # Suggest optimal configuration
                        suggested_workers = min(cpu_count * 2, 16)  # Generally good for I/O bound tasks
                        suggested_workers_per_q = min(4, suggested_workers // 2)
                    
                        print(f"\n💡 Suggested Configuration:")
                        print(f"   MAX_WORKERS: {suggested_workers}")
                        print(f"   MAX_WORKERS_PER_QUESTION: {suggested_workers_per_q}")
                    
                        if CONFIG['max_workers'] > suggested_workers:
                            print(f"⚠️  Current MAX_WORKERS ({CONFIG['max_workers']}) might be too high for your system")
                    
                        return suggested_workers, suggested_workers_per_q
                    
                    # ===========================
                    # Main Execution
                    # ===========================
                    
                    # Monitor system performance
                    monitor_performance()
                    
                    print(f"\n🚀 Starting parallel processing:")
                    print(f"   Model: {CONFIG['model_path']}")
                    print(f"   Template: {CONFIG['template']}")
                    print(f"   Dataset: {CONFIG['dataset']}")
                    print(f"   Traces per question: {CONFIG['max_samples_per_question']}")
                    print(f"   Max workers: {CONFIG['max_workers']}")
                    print(f"   Max workers per question: {CONFIG['max_workers_per_question']}")
                    print(f"   Output directory: {CONFIG['output_dir']}")
                    if CONFIG['template'] == 'gpt':
                        print(f"   Reasoning effort: {CONFIG['reasoning_effort']}")
                    
                    start_time = time.time()
                    
                    # Process the dataset with parallel processing
                    output_dir = process_dataset_parallel(
                        dataset_file=CONFIG['dataset'],
                        output_dir=CONFIG['output_dir']
                    )
                    
                    end_time = time.time()
                    processing_time = end_time - start_time
                    
                    # Check results
                    if output_dir:
                        print("\n" + "="*50)
                        print("Checking generated results...")
                        check_results(output_dir)
                    
                    print(f"\n⚡ Processing completed in {processing_time:.2f} seconds!")
                    print(f"🚀 Speed improvement with parallel processing!")
                    print("Note: vLLM server is still running. Stop it manually if needed.")
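
The `question_files.sort(...)` key above orders result files by the numeric question ID at the start of each filename and sends names without a numeric prefix to the end. Here is the same key pulled out into a named function. `question_sort_key` is our name for illustration and is not part of the script:

```python
def question_sort_key(filename: str) -> float:
    """Numeric prefix of names like '12.jsonl' or '12_processed.jsonl'; others sort last."""
    stem = filename.split('.')[0]   # '12_processed.jsonl' -> '12_processed'
    prefix = stem.split('_')[0]     # '12_processed' -> '12'
    return int(prefix) if prefix.isdigit() else float('inf')


files = ["10.jsonl", "2.jsonl", "notes.jsonl", "1_processed.jsonl"]
print(sorted(files, key=question_sort_key))
# ['1_processed.jsonl', '2.jsonl', '10.jsonl', 'notes.jsonl']
```

A plain string sort would put '10.jsonl' before '2.jsonl', so the script converts the prefix to an integer.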

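`monitor_performance()` sizes the thread pools with a rule of thumb that, per its inline comment, targets I/O-bound work, since workers mostly wait on requests to the vLLM server. The rule is twice the CPU core count, capped at 16 overall, and at most 4 (or half the total) workers per question. The pure-function restatement below uses the hypothetical name `suggest_workers` and makes the arithmetic easy to check:

```python
def suggest_workers(cpu_count):
    """Same heuristic as monitor_performance(); cpu_count may be None."""
    cores = cpu_count or 1                   # psutil.cpu_count() can return None
    max_workers = min(cores * 2, 16)         # 2x cores, capped at 16
    per_question = min(4, max_workers // 2)  # at most 4, never more than half the pool
    return max_workers, per_question


for cores in (1, 2, 4, 16):
    print(cores, suggest_workers(cores))
# 1 (2, 1)
# 2 (4, 2)
# 4 (8, 4)
# 16 (16, 4)
```
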
Example 2: Offline Voting

import os
                    import json
                    import pickle
                    import numpy as np
                    import pandas as pd
                    from collections import Counter, defaultdict
                    from tqdm import tqdm
                    import glob
                    import random
                    from typing import Dict, List, Tuple, Optional, Any
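                    from dynasor.core.evaluator import math_equal
                    # Parsers used by parse_func below (sympy's parse_latex also needs antlr4-python3-runtime)
                    from sympy.parsing.latex import parse_latex
                    from sympy.parsing.sympy_parser import parse_expr
                    from latex2sympy2 import latex2sympy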
                    
                    
                    # Placeholders: list your trace output directories here (add as many as needed)
                    TRACE_DIRS = ["YOURDIR1", "YOURDIR2"]
                    
                    # ============================================
                    # SECTION 1: Math Evaluation Functions
                    # ============================================
                    
                    
                    def parse_func(s):
                        """Try each parser in turn; return the input string unchanged if all fail."""
                        for f in [parse_latex, parse_expr, latex2sympy]:
                            try:
                                return f(s.replace("\\\\", "\\"))
                            except Exception:
                                try:
                                    return f(s)
                                except Exception:
                                    pass
                        return s
                    
                    def quick_parse(text):
                        """Quick parse to remove LaTeX text formatting"""
                        if '\\text{' in text and '}' in text:
                            while '\\text{' in text:
                                start = text.find('\\text{')
                                if start == -1:
                                    break
                                end = text.find('}', start)
                                if end == -1:
                                    break
                                content = text[start + 6:end]
                                text = text[:start] + content + text[end + 1:]
                        return text
                    
                    # ============================================
                    # SECTION 2: Answer Extraction
                    # ============================================
                    
                    def extract_answer(text):
                        """Extract boxed answer from text"""
                        if "boxed" in text:
                            ans = text.split("boxed")[-1]
                            if len(ans) == 0:
                                return ""
                            elif ans[0] == "{":
                                stack = 1
                                a = ""
                                for c in ans[1:]:
                                    if c == "{":
                                        stack += 1
                                        a += c
                                    elif c == "}":
                                        stack -= 1
                                        if stack == 0:
                                            break
                                        a += c
                                    else:
                                        a += c
                            else:
                                a = ans.split("$")[0].strip()
                            return a.strip()
                        return None
                    
                    # ============================================
                    # SECTION 3: Confidence Metrics Calculation
                    # ============================================
                    
                    def calculate_confidence_stats(conf_list, tokens):
                        """Calculate various confidence statistics"""
                        if not conf_list:
                            return {}
                    
                        assert len(conf_list) == len(tokens), "Confidence list and tokens must have same length"
                    
                        conf_array = np.array(conf_list)
                        total_tokens = len(conf_array)
                    
                        stats = {
                            'mean_confidence': np.mean(conf_array)
                        }
                    
                        # Mean confidence over the last N tokens (whole trace if shorter than N)
                        for n in [2048]:
                            if total_tokens >= n:
                                stats[f'tail_{n}_mean_conf'] = np.mean(conf_array[-n:])
                            else:
                                stats[f'tail_{n}_mean_conf'] = np.mean(conf_array)
                    
                        # Mean confidence over the last `ratio` fraction of tokens (at least one token)
                        for ratio in [0.1]:
                            n_tokens = max(1, int(total_tokens * ratio))
                            stats[f'tail_{ratio}_mean_conf'] = np.mean(conf_array[-n_tokens:])
                    
                        # Sliding window metrics
                        window_sizes = [2048]
                        bottom_percentages = [0.1, 0.5]
                    
                        for window_size in window_sizes:
                            if total_tokens < window_size:
                                stats[f'min_sliding_{window_size}_mean_conf'] = np.mean(conf_array)
                                for percent in bottom_percentages:
                                    stats[f'bottom_{percent}_sliding_{window_size}_mean_conf'] = np.mean(conf_array)
                            else:
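                                # O(N) sliding-window means via prefix sums:
                                # sum of the window starting at i = cumsum[i + W - 1] - cumsum[i - 1] (i >= 1)
                                cumsum = np.cumsum(conf_array)
                                window_sums = cumsum[window_size-1:].copy()  # copy so the subtraction below does not write into cumsum
                                window_sums[1:] -= cumsum[:-window_size]
                                window_means = window_sums / window_size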
                                stats[f'min_sliding_{window_size}_mean_conf'] = np.min(window_means)
                    
                                sorted_means = np.sort(window_means)
                    
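                                for percent in bottom_percentages:
                                    # Mean of the lowest `percent` fraction of window means; keep at least
                                    # one window so short traces do not produce the mean of an empty slice (NaN)
                                    idx = max(1, int(len(sorted_means) * percent))
                                    stats[f'bottom_{percent}_sliding_{window_size}_mean_conf'] = sorted_means[:idx].mean()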
                    
                        return stats
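                    
                    # Reference check (hypothetical helpers, not part of the original pipeline): a direct,
                    # O(N * W) restatement of the sliding-window statistics above, handy for verifying the
                    # prefix-sum version on short toy inputs.
                    def naive_sliding_window_means(conf_list, window_size):
                        """Mean of every contiguous run of `window_size` confidences (empty if too short)."""
                        n_windows = len(conf_list) - window_size + 1
                        return [sum(conf_list[i:i + window_size]) / window_size for i in range(n_windows)]
                    
                    def naive_bottom_percent_conf(conf_list, window_size, percent):
                        """Mean of the lowest `percent` fraction of window means (keeps at least one window)."""
                        means = sorted(naive_sliding_window_means(conf_list, window_size))
                        if not means:
                            # Trace shorter than one window: fall back to the global mean, as above
                            return sum(conf_list) / len(conf_list)
                        k = max(1, int(len(means) * percent))
                        return sum(means[:k]) / k
                    
                    # e.g. naive_sliding_window_means([3, 1, 2, 6], 2) -> [2.0, 1.5, 4.0]
                    #      naive_bottom_percent_conf([3, 1, 2, 6], 2, 0.5) -> 1.5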
                    
                    # ============================================
                    # SECTION 4: Voting Strategies
                    # ============================================
                    
                    def majority_vote(traces):
                        """Perform majority voting based on extracted answers"""
                        if not traces:
                            return None, None
                    
                        answer_counts = {}
                        answer_to_parsed = {}
                    
                        for trace in traces:
                            extracted_answer = trace.get('extracted_answer')
                            parsed_answer = trace.get('parsed_answer')
                    
                            if extracted_answer is not None:
                                answer_str = str(extracted_answer)
                                answer_counts[answer_str] = answer_counts.get(answer_str, 0) + 1
                                if answer_str not in answer_to_parsed:
                                    answer_to_parsed[answer_str] = parsed_answer
                    
                        if not answer_counts:
                            return None, None
                    
                        voted_answer = max(answer_counts.keys(), key=lambda x: answer_counts[x])
                        voted_parsed = answer_to_parsed[voted_answer]
                    
                        return voted_answer, voted_parsed
                    
                    def weighted_majority_vote(traces, weight_key='mean_confidence'):
                        """Perform weighted majority voting"""
                        if not traces:
                            return None, None
                    
                        answer_weights = {}
                        answer_to_parsed = {}
                    
                        for trace in traces:
                            extracted_answer = trace.get('extracted_answer')
                            parsed_answer = trace.get('parsed_answer')
                            weight = trace.get(weight_key)
                    
                            if extracted_answer is not None and weight is not None:
                                answer_str = str(extracted_answer)
                                answer_weights[answer_str] = answer_weights.get(answer_str, 0.0) + float(weight)
                                if answer_str not in answer_to_parsed:
                                    answer_to_parsed[answer_str] = parsed_answer
                    
                        if not answer_weights:
                            return None, None
                    
                        voted_answer = max(answer_weights.keys(), key=lambda x: answer_weights[x])
                        voted_parsed = answer_to_parsed[voted_answer]
                    
                        return voted_answer, voted_parsed
                    
                    def top_percent_vote(traces, weight_key='mean_confidence', top_percent=0.1, vote_strategy='majority'):
                        """
                        First filter top percent of traces by weight_key, then perform voting
                    
                        Args:
                            traces: List of trace dictionaries
                            weight_key: Key to use for filtering (e.g., 'mean_confidence')
                            top_percent: Percentage of top traces to keep (e.g., 0.1 for top 10%)
                            vote_strategy: 'majority' or 'weighted'
                    
                        Returns:
                            voted_answer, voted_parsed
                        """
                        if not traces:
                            return None, None
                    
                        # Filter traces that have the weight_key and valid answers
                        valid_traces = [t for t in traces if weight_key in t and t.get('extracted_answer') is not None]
                    
                        if not valid_traces:
                            return None, None
                    
                        # Sort traces by weight_key in descending order (higher is better)
                        sorted_traces = sorted(valid_traces, key=lambda x: x[weight_key], reverse=True)
                    
                        # Select top percent
                        n_top = max(1, int(len(sorted_traces) * top_percent))
                        top_traces = sorted_traces[:n_top]
                    
                        # Apply voting strategy on filtered traces
                        if vote_strategy == 'majority':
                            return majority_vote(top_traces)
                        elif vote_strategy == 'weighted':
                            return weighted_majority_vote(top_traces, weight_key)
                        else:
                            raise ValueError(f"Unknown vote_strategy: {vote_strategy}")
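                    
                    # Worked toy example (hypothetical traces): three traces answer 'A' with
                    # mean_confidence 0.9, 'B' with 0.2 and 'B' with 0.3.
                    #   majority_vote(traces)[0]                      -> 'B'  (2 votes vs 1)
                    #   weighted_majority_vote(traces)[0]             -> 'A'  (0.9 vs 0.2 + 0.3 = 0.5)
                    #   top_percent_vote(traces, top_percent=0.5)[0]  -> 'A'  (keeps max(1, int(3 * 0.5)) = 1
                    #                                                          trace, the most confident one)
                    # Confidence weighting or filtering can therefore overturn a plain head count.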
                    
                    # ============================================
                    # SECTION 5: JSONL Processing
                    # ============================================
                    
                    def process_jsonl_file(file_path, ground_truth=None):
                        """Process a single JSONL file and extract traces with metrics"""
                        traces = []
                    
                        with open(file_path, 'r', encoding='utf-8') as f:
                            lines = f.readlines()
                    
                        for line_num, line in enumerate(lines):
                            if not line.strip():
                                continue
                    
                            try:
                                data = json.loads(line)
                    
                                # Extract response and confidence data
                                response = data.get('response', '')
                                mean_confidences = data.get('mean_confidences', [])
                                tokens = data.get('tokens', [])
                                question_meta = data['question_meta']['original_question']
                    
                                # Extract answer
                                extracted_answer = extract_answer(response)
                                parsed_answer = parse_func(extracted_answer) if extracted_answer else None
                    
                                # Get ground truth
                                if ground_truth is None:
                                    # Try to extract from question_meta
                                    for field in ['answer', 'solution', 'target']:
                                        if field in question_meta:
                                            ground_truth = str(question_meta[field]).strip()
                                            break
                    
                                # Calculate confidence statistics
                                conf_stats = calculate_confidence_stats(mean_confidences, tokens)
                    
                                # Check correctness
                                is_correct = False
                                if extracted_answer is not None and ground_truth is not None:
                                    is_correct = math_equal(extracted_answer, ground_truth)
                    
                                # Create trace entry
                                trace = {
                                    'trace_id': data.get('trace_id', line_num),
                                    'extracted_answer': extracted_answer,
                                    'parsed_answer': parsed_answer,
                                    'is_correct': is_correct,
                                    'ground_truth': ground_truth,
                                    'response': response,
                                    **conf_stats
                                }
                    
                                traces.append(trace)
                    
                            except Exception as e:
                                print(f"Error processing line {line_num} in {file_path}: {e}")
                                continue
                    
                        return traces
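                    
                    # Input-format sketch (hypothetical helper; the field names are the ones read by
                    # process_jsonl_file above, every value is a placeholder). One JSON object per line,
                    # and `mean_confidences` needs one entry per element of `tokens`.
                    def write_toy_trace_file(file_path, answers, ground_truth="42"):
                        """Write one trace per answer in the layout process_jsonl_file expects."""
                        with open(file_path, "w", encoding="utf-8") as f:
                            for i, ans in enumerate(answers):
                                record = {
                                    "trace_id": i,
                                    "response": f"... so the final answer is \\boxed{{{ans}}}.",
                                    "tokens": ["tok"] * 4,
                                    "mean_confidences": [10.0 + i] * 4,
                                    "question_meta": {"original_question": {"answer": ground_truth}},
                                }
                                f.write(json.dumps(record) + "\n")
                        return file_path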
                    
                    def process_multiple_jsonls(file_pattern, ground_truth_map=None):
                        """Process multiple JSONL files matching a pattern"""
                        files = glob.glob(file_pattern)
                        all_data = defaultdict(list)
                    
                        for file_path in tqdm(files, desc="Processing JSONL files"):
                            # Extract question ID from filename if possible
                            filename = os.path.basename(file_path)
                            question_id = None
                    
                            # Try to extract question ID (adjust pattern as needed)
                            if '_processed.jsonl' in filename:
                                try:
                                    question_id = int(filename.replace('_processed.jsonl', ''))
                                except ValueError:  # non-numeric prefix: use the filename as the ID
                                    question_id = filename
                            else:
                                question_id = filename
                    
                            # Get ground truth for this question
                            ground_truth = None
                            if ground_truth_map and question_id in ground_truth_map:
                                ground_truth = ground_truth_map[question_id]
                    
                            # Process the file
                            traces = process_jsonl_file(file_path, ground_truth)
                    
                            if traces:
                                all_data[question_id] = traces
                    
                        return dict(all_data)
                    
                    def process_multiple_dirs_jsonls(trace_dirs, file_pattern="*_processed.jsonl", ground_truth_map=None):
                        """
                        Process JSONL files from multiple directories and merge traces with same filename
                    
                        Args:
                            trace_dirs: List of directory paths to search for JSONL files
                            file_pattern: File pattern to match (e.g., "*_processed.jsonl")
                            ground_truth_map: Optional dictionary mapping question IDs to ground truth answers
                    
                        Returns:
                            Dictionary where keys are question IDs and values are lists of merged traces
                        """
                        all_data = defaultdict(list)
                    
                        # First, collect all unique filenames across all directories
                        all_filenames = set()
                        dir_file_mapping = defaultdict(list)  # Track which dirs have which files
                    
                        for trace_dir in trace_dirs:
                            if not os.path.exists(trace_dir):
                                print(f"Warning: Directory {trace_dir} does not exist, skipping...")
                                continue
                    
                            pattern = os.path.join(trace_dir, file_pattern)
                            files = glob.glob(pattern)
                            print(f"Found {len(files)} files in {trace_dir}")
                    
                            for file_path in files:
                                filename = os.path.basename(file_path)
                                all_filenames.add(filename)
                                dir_file_mapping[filename].append(trace_dir)
                    
                        print(f"Total unique filenames found: {len(all_filenames)}")
                    
                        # Process each unique filename across all directories
                        for filename in tqdm(all_filenames, desc="Processing unique files"):
                            # Extract question ID from filename
                            question_id = None
                            if '_processed.jsonl' in filename:
                                try:
                                    question_id = int(filename.replace('_processed.jsonl', ''))
                                except ValueError:  # non-numeric prefix: use the filename as the ID
                                    question_id = filename
                            else:
                                question_id = filename
                    
                            # Get ground truth for this question
                            ground_truth = None
                            if ground_truth_map and question_id in ground_truth_map:
                                ground_truth = ground_truth_map[question_id]
                    
                            # Collect traces from all directories for this filename
                            merged_traces = []
                            dirs_with_file = dir_file_mapping[filename]
                    
                            for trace_dir in dirs_with_file:
                                file_path = os.path.join(trace_dir, filename)
                                if os.path.exists(file_path):
                                    try:
                                        traces = process_jsonl_file(file_path, ground_truth)
                    
                                        # Add directory info to each trace for identification
                                        for i, trace in enumerate(traces):
                                            trace['source_dir'] = trace_dir
                                            trace['source_file'] = filename
                                            # Create unique trace ID combining dir and original trace ID
                                            original_trace_id = trace.get('trace_id', i)
                                            trace['trace_id'] = f"{os.path.basename(trace_dir)}_{original_trace_id}"
                    
                                        merged_traces.extend(traces)
                    
                                    except Exception as e:
                                        print(f"Error processing {file_path}: {e}")
                                        continue
                    
                            if merged_traces:
                                all_data[question_id] = merged_traces
                                print(f"Question {question_id}: Merged {len(merged_traces)} traces from {len(dirs_with_file)} directories")
                    
                        return dict(all_data)
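                    
                    # Filename-to-ID rule shared by both loaders above, restated as a hypothetical helper:
                    # "<int>_processed.jsonl" becomes an integer question ID, any other filename is used
                    # verbatim. Keys in ground_truth_map must follow the same convention to be matched.
                    def question_id_from_filename(filename):
                        if '_processed.jsonl' in filename:
                            try:
                                return int(filename.replace('_processed.jsonl', ''))
                            except ValueError:
                                return filename
                        return filename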
                    
                    # ============================================
                    # SECTION 6: Analysis and Evaluation
                    # ============================================
                    
                    def analyze_voting_performance(data, voting_sizes=[1, 2, 4, 8, 16, 32],
                                                  strategy='majority', weight_key='mean_confidence',
                                                  n_trials=1, seed=42, top_percent=None):
                        """Analyze voting performance across different ensemble sizes"""
                    
                        random.seed(seed)
                        np.random.seed(seed)
                    
                        results = {}
                        for vote_size in voting_sizes:
                            accuracies = []
                    
                            for trial in range(n_trials):
                                correct = 0
                                total = 0
                    
                                for question_id, traces in data.items():
                                    if len(traces) < vote_size:
                                        continue
                    
                                    # Sample traces
                                    sampled = random.sample(traces, vote_size)
                    
                                    # Apply voting strategy
                                    if strategy == 'majority':
                                        voted_answer, _ = majority_vote(sampled)
                                    elif strategy == 'weighted':
                                        voted_answer, _ = weighted_majority_vote(sampled, weight_key)
                                    elif strategy == 'top_percent':
                                        voted_answer, _ = top_percent_vote(sampled, weight_key, top_percent, 'majority')
                                    elif strategy == 'top_percent_weighted':
                                        voted_answer, _ = top_percent_vote(sampled, weight_key, top_percent, 'weighted')
                                    else:
                                        raise ValueError(f"Unknown strategy: {strategy}")
                    
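                                        # Check correctness; every question with a ground truth counts toward
                                        # the denominator, so a vote that yields no answer is scored as wrong
                                        ground_truth = sampled[0]['ground_truth']
                                        if ground_truth is not None:
                                            if voted_answer is not None and math_equal(voted_answer, ground_truth):
                                                correct += 1
                                            total += 1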
                    
                                if total > 0:
                                    accuracies.append(correct / total)
                    
                            if accuracies:
                                results[vote_size] = {
                                    'accuracy_mean': np.mean(accuracies),
                                    'accuracy_std': np.std(accuracies)
                                }
                    
                        return results
                    
                    def analyze_top_percent_strategies(data, voting_sizes=[1, 2, 4, 8],
                                                     weight_keys=['mean_confidence', 'tail_2048_mean_conf'],
                                                     top_percents=[0.1, 0.2, 0.3, 0.5],
                                                     n_trials=1, seed=42):
                        """
                        Comprehensive analysis of top percent filtering strategies
                    
                        Args:
                            data: Processed trace data
                            voting_sizes: List of ensemble sizes to test
                            weight_keys: List of confidence metrics to use for filtering
                            top_percents: List of top percentages to test (e.g., [0.1, 0.2] for top 10%, 20%)
                            n_trials: Number of random trials per configuration
                            seed: Random seed for reproducibility
                    
                        Returns:
                            Dictionary with results for each configuration
                        """
                    
                        results = {}
                    
                        # Test each combination of parameters
                        for weight_key in weight_keys:
                            for top_percent in top_percents:
                                for vote_strategy in ['weighted']:
                                        strategy_name = f"top_{round(top_percent*100)}%_{vote_strategy}_{weight_key}"
                    
                                    print(f"Testing {strategy_name}...")
                    
                                    strategy = 'top_percent' if vote_strategy == 'majority' else 'top_percent_weighted'
                    
                                    strategy_results = analyze_voting_performance(
                                        data,
                                        voting_sizes=voting_sizes,
                                        strategy=strategy,
                                        weight_key=weight_key,
                                        n_trials=n_trials,
                                        seed=seed,
                                        top_percent=top_percent
                                    )
                    
                                    results[strategy_name] = strategy_results
                    
                        return results
                    
                    def analyze_directory_distribution(data):
                        """Analyze the distribution of traces across source directories"""
                        print("\n" + "="*60)
                        print("DIRECTORY DISTRIBUTION ANALYSIS")
                        print("="*60)
                    
                        dir_stats = defaultdict(lambda: {'total_traces': 0, 'questions': set()})
                    
                        for question_id, traces in data.items():
                            for trace in traces:
                                source_dir = trace.get('source_dir', 'unknown')
                                dir_stats[source_dir]['total_traces'] += 1
                                dir_stats[source_dir]['questions'].add(question_id)
                    
                        print(f"{'Directory':<50} {'Traces':<10} {'Questions':<10}")
                        print("-" * 72)
                    
                        for dir_name, stats in dir_stats.items():
                            short_name = os.path.basename(dir_name) if dir_name != 'unknown' else 'unknown'
                            print(f"{short_name:<50} {stats['total_traces']:<10} {len(stats['questions']):<10}")
                    
                        return dir_stats
                    
                    # ============================================
                    # MAIN EXECUTION CELL
                    # ============================================
                    
                    def main_analysis_multi_dir(trace_dirs=None, file_pattern="*_processed.jsonl",
                                               ground_truth_file=None, output_dir="./results"):
                        """
                        Main analysis function for multiple directories
                    
                        Args:
                            trace_dirs: List of directory paths to search for JSONL files
                            file_pattern: File pattern to match (e.g., "*_processed.jsonl")
                            ground_truth_file: Optional pickle file with ground truth answers
                            output_dir: Directory to save results
                        """
                    
                        if trace_dirs is None:
                            trace_dirs = TRACE_DIRS
                    
                        print("="*60)
                        print("ENHANCED MULTI-DIRECTORY JSONL VOTING ANALYSIS")
                        print("="*60)
                        print(f"Processing {len(trace_dirs)} directories:")
                        for i, dir_path in enumerate(trace_dirs):
                            print(f"  {i+1}. {dir_path}")
                    
                        # Load ground truth if provided
                        ground_truth_map = None
                        if ground_truth_file and os.path.exists(ground_truth_file):
                            with open(ground_truth_file, 'rb') as f:
                                ground_truth_map = pickle.load(f)
                            print(f"Loaded ground truth from {ground_truth_file}")
                    
                        # Process JSONL files from multiple directories
                        print(f"\nProcessing files with pattern: {file_pattern}")
                        data = process_multiple_dirs_jsonls(trace_dirs, file_pattern, ground_truth_map)
                    
                        print(f"Processed {len(data)} questions")
                        total_traces = sum(len(traces) for traces in data.values())
                        print(f"Total traces: {total_traces}")
                    
                        # Analyze directory distribution
                        dir_stats = analyze_directory_distribution(data)
                    
                        # Calculate per-question statistics
                        print("\n" + "="*60)
                        print("PER-QUESTION STATISTICS")
                        print("="*60)
                    
                        question_items = list(data.items())
                        for q_id, traces in question_items[:5]:  # Show first 5 questions
                            correct = sum(1 for t in traces if t['is_correct'])
                            if traces:  # guard before dividing by len(traces)
                                print(f"Question {q_id}: {correct}/{len(traces)} correct ({correct/len(traces):.2%})")
                                confs = [t['mean_confidence'] for t in traces if 'mean_confidence' in t]
                                mean_conf = np.mean(confs) if confs else float('nan')
                                print(f"  Mean confidence: {mean_conf:.4f}")
                    
                                # Show directory breakdown
                                dir_breakdown = defaultdict(int)
                                for trace in traces:
                                    dir_name = os.path.basename(trace.get('source_dir', 'unknown'))
                                    dir_breakdown[dir_name] += 1
                                print(f"  Directory breakdown: {dict(dir_breakdown)}")
                    
                        # Test baseline strategies
                        print("\n" + "="*60)
                        print("BASELINE VOTING STRATEGIES")
                        print("="*60)
                    
                        baseline_strategies = [
                            ('majority', 'Majority Vote', None),
                            ('weighted', 'Weighted Vote (mean_conf)', 'mean_confidence'),
                            ('weighted', 'Weighted Vote (tail_2048)', 'tail_2048_mean_conf'),
                        ]
                    
                        voting_sizes = [1, 2, 4, 8, 16, 32, 64, 128]
                    
                        all_results = {}
                    
                        for strategy, name, weight_key in baseline_strategies:
                            print(f"\n{name}:")
                            results = analyze_voting_performance(
                                data,
                                voting_sizes=voting_sizes,
                                strategy=strategy,
                                weight_key=weight_key,
                                n_trials=10
                            )
                    
                            all_results[name] = results
                    
                            print(f"{'Size':<6} {'Accuracy':<12} {'Std Dev':<10}")
                            print("-" * 30)
                            for size in voting_sizes:
                                if size in results:
                                    acc = results[size]['accuracy_mean']
                                    std = results[size]['accuracy_std']
                                    print(f"{size:<6} {acc:<12.4f} {std:<10.4f}")
                    
                        # Test top percent filtering strategies
                        print("\n" + "="*60)
                        print("TOP PERCENT FILTERING STRATEGIES")
                        print("="*60)
                    
                        top_percent_results = analyze_top_percent_strategies(
                            data,
                            voting_sizes=voting_sizes,  # same ensemble sizes as the baseline strategies
                            weight_keys=['mean_confidence', 'tail_2048_mean_conf', 'bottom_0.1_sliding_2048_mean_conf'],
                            top_percents=[0.1, 0.9],
                            n_trials=10
                        )
                    
                        all_results.update(top_percent_results)
                    
                        # Display top percent results
                        for strategy_name, strategy_results in top_percent_results.items():
                            print(f"\n{strategy_name}:")
                            print(f"{'Size':<6} {'Accuracy':<12} {'Std Dev':<10}")
                            print("-" * 30)
                            for size in voting_sizes:
                                if size in strategy_results:
                                    acc = strategy_results[size]['accuracy_mean']
                                    std = strategy_results[size]['accuracy_std']
                                    print(f"{size:<6} {acc:<12.4f} {std:<10.4f}")
                    
                        # Save results
                        os.makedirs(output_dir, exist_ok=True)
                    
                        # Save processed data
                        data_path = os.path.join(output_dir, "processed_data_multi_dir.pkl")
                        with open(data_path, 'wb') as f:
                            pickle.dump(data, f)
                        print(f"\n✓ Saved processed data to {data_path}")
                    
                        # Save voting results
                        results_path = os.path.join(output_dir, "voting_results_multi_dir.pkl")
                        with open(results_path, 'wb') as f:
                            pickle.dump(all_results, f)
                        print(f"✓ Saved voting results to {results_path}")
                    
                        # Create comprehensive summary DataFrame
                        summary_data = []
                        for strategy_name, strategy_results in all_results.items():
                            for size, metrics in strategy_results.items():
                                summary_data.append({
                                    'Strategy': strategy_name,
                                    'Ensemble Size': size,
                                    'Accuracy': metrics['accuracy_mean'],
                                    'Std Dev': metrics['accuracy_std']
                                })
                    
                        df_summary = pd.DataFrame(summary_data)
                        csv_path = os.path.join(output_dir, "voting_summary_multi_dir.csv")
                        df_summary.to_csv(csv_path, index=False)
                        print(f"✓ Saved summary CSV to {csv_path}")
                    
                        # Find best performing strategies
                        print("\n" + "="*60)
                        print("BEST PERFORMING STRATEGIES")
                        print("="*60)
                    
                        # Group by ensemble size and find best accuracy for each
                        for size in voting_sizes:
                            size_results = df_summary[df_summary['Ensemble Size'] == size]
                            if not size_results.empty:
                                best_row = size_results.loc[size_results['Accuracy'].idxmax()]
                                print(f"Size {size}: {best_row['Strategy']} (Accuracy: {best_row['Accuracy']:.4f})")
                    
                        return data, all_results, df_summary, dir_stats
                    
                    # ============================================
                    # EXAMPLE USAGE
                    # ============================================
                    
                    if __name__ == "__main__":
                        # Example usage - modify these paths for your data
                    
                        # Use the predefined TRACE_DIRS or specify your own
                        custom_dirs = None  # Set to your list of directories if different from TRACE_DIRS
                    
                        file_pattern = "*_processed.jsonl"  # Adjust this pattern
                        ground_truth_file = None  # Optional: "./ground_truth.pkl"
                        output_directory = "./voting_results_multi_dir"
                    
                        # Run the analysis
                        data, results, summary_df, dir_stats = main_analysis_multi_dir(
                            trace_dirs=custom_dirs,  # Will use TRACE_DIRS if None
                            file_pattern=file_pattern,
                            ground_truth_file=ground_truth_file,
                            output_dir=output_directory
                        )
                    
                        # Display summary
                        print("\n" + "="*60)
                        print("FINAL SUMMARY")
                        print("="*60)
                        print(summary_df.to_string(index=False))
                    
                        # Show some example merged data
                        print("\n" + "="*60)
                        print("EXAMPLE MERGED DATA")
                        print("="*60)
                    
                        for q_id, traces in list(data.items())[:2]:  # Show first 2 questions
                            print(f"\nQuestion {q_id} ({len(traces)} traces):")
                            for trace in traces[:3]:  # Show first 3 traces
                                source_dir = os.path.basename(trace.get('source_dir', 'unknown'))
                                correct = trace.get('is_correct', False)
                                conf = trace.get('mean_confidence', 0)
                                answer = trace.get('extracted_answer', 'None')
                                print(f"  {source_dir}: {answer} (correct: {correct}, conf: {conf:.4f})") 
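
`analyze_voting_performance` and the other voting helpers are defined elsewhere in Example 2 and are not reproduced here. As a rough standalone illustration of what the `top_percent` and `top_percent_weighted` strategies are meant to do, the sketch below keeps the most confident fraction of one question's traces and then votes among them. It assumes each trace is a dict with an `extracted_answer` and a score under `weight_key` (for example `mean_confidence`), that a higher score means higher confidence, and that the weighted variant sums scores per answer. `top_percent_vote` is a hypothetical name, not a function from the notebook, and the trace values are toy numbers.

```python
from collections import defaultdict
import math


def top_percent_vote(traces, weight_key, top_percent=0.1, weighted=False):
    """Hypothetical sketch: keep the most confident traces, then vote.

    traces: list of dicts with 'extracted_answer' and a confidence under weight_key.
    top_percent: fraction of traces to keep, ranked by confidence (highest first).
    weighted: if True, each kept trace votes with its confidence; otherwise with 1.
    """
    scored = [t for t in traces if t.get(weight_key) is not None]
    if not scored:
        return None
    # Round the kept count up so that at least one trace always votes.
    keep = max(1, math.ceil(len(scored) * top_percent))
    kept = sorted(scored, key=lambda t: t[weight_key], reverse=True)[:keep]

    votes = defaultdict(float)
    for t in kept:
        answer = t.get('extracted_answer')
        if answer is None:
            continue
        votes[str(answer)] += t[weight_key] if weighted else 1.0
    if not votes:
        return None
    return max(votes, key=votes.get)


# Toy traces for one question (illustrative values, not real results).
traces = [
    {'extracted_answer': '70', 'mean_confidence': 14.2},
    {'extracted_answer': '70', 'mean_confidence': 13.9},
    {'extracted_answer': '588', 'mean_confidence': 9.1},
    {'extracted_answer': '588', 'mean_confidence': 8.7},
    {'extracted_answer': '588', 'mean_confidence': 8.5},
]
print(top_percent_vote(traces, 'mean_confidence', top_percent=1.0))                 # plain majority
print(top_percent_vote(traces, 'mean_confidence', top_percent=0.4))                 # 2 most confident
print(top_percent_vote(traces, 'mean_confidence', top_percent=1.0, weighted=True))  # confidence-weighted
```

With all five toy traces the plain majority picks `588`; keeping only the two most confident traces, or weighting every vote by its confidence, flips the result to `70`.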

Example 3: Online Generation
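
The quantities this example computes can be written compactly. They are read off the client-side helpers in the script (`compute_confidence`, `compute_least_grouped`, and the `np.percentile` call that sets `conf_bar`); the symbols are our own notation, and we assume the server-side window check uses the same group mean.

```latex
\begin{aligned}
C_i &= -\frac{1}{k}\sum_{j=1}^{k}\log P_i(j) && \text{token confidence, } k = \texttt{top\_logprobs} = 20\\
G_t &= \frac{1}{w}\sum_{i=t}^{t+w-1} C_i,\quad t = 1,\dots,n-w+1 && \text{group confidence, } w = \texttt{WINDOW\_SIZE} = 2048\\
m &= \min_t G_t && \texttt{min\_conf}\text{ of one trace}\\
s &= \mathrm{P}_{90}\!\left(m^{(1)},\dots,m^{(16)}\right) && \texttt{conf\_bar}\text{ over the 16 warmup traces}
\end{aligned}
```

Here $P_i(j)$ is the probability of the $j$-th ranked candidate at position $i$ and $n$ is the trace length; a trace shorter than $w$ gets a single group confidence equal to the mean of all its $C_i$. Warmup traces with $m < s$ are excluded from the final vote in `process_problem_voting`.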

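A rough way to see how `threshold` and `window_size` interact, before running the patched server, is to replay the stopping rule offline on one trace's per-token confidences (for example the `confs` list that `process_trace` stores). The sketch assumes the server-side `check_conf_stop()` stops a trace as soon as the mean confidence over the most recent `window_size` tokens falls below `threshold`, and only once a full window is available; that behaviour is an assumption, not read from the patch. `first_stop_index` is a hypothetical helper and all numbers are toy values.

```python
from collections import deque

import numpy as np


def first_stop_index(confs, threshold, window_size=2048):
    """Hypothetical offline replay of a sliding-window early-stop rule.

    Returns how many tokens would have been generated when the mean of the most
    recent `window_size` token confidences first drops below `threshold`, or None
    if that never happens. Assumes the rule only fires once a full window exists.
    """
    window = deque()
    running_sum = 0.0
    for i, c in enumerate(confs):
        window.append(c)
        running_sum += c
        if len(window) > window_size:
            running_sum -= window.popleft()  # slide: drop the oldest token
        if len(window) == window_size and running_sum / window_size < threshold:
            return i + 1  # tokens produced up to and including position i
    return None


# Toy lowest group confidences of four warmup traces (not real model output).
warmup_min_confs = [10.0, 12.0, 14.0, 16.0]
threshold = float(np.percentile(warmup_min_confs, 90))  # same rule as conf_bar

# A toy trace that starts confident and then degrades; 4-token window for readability.
confs = [16.0] * 6 + [8.0] * 6
stop_at = first_stop_index(confs, threshold, window_size=4)
print(f"threshold={threshold:.2f}, stop after {stop_at} of {len(confs)} tokens")
# prints: threshold=15.40, stop after 7 of 12 tokens
```

Comparing the returned index with a warmup trace's `token_count` gives a rough estimate of how many tokens the early stop would have saved on that trace.
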
                    import openai
                    import json
                    from tqdm import tqdm
                    import time
                    import os
                    from datetime import datetime
                    import numpy as np
                    from collections import Counter
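                    from dynasor.core.evaluator import math_equal
                    # Parsers used by parse_func (sympy's parse_latex also needs the antlr4-python3-runtime package)
                    from sympy.parsing.latex import parse_latex
                    from sympy.parsing.sympy_parser import parse_expr
                    from latex2sympy2 import latex2sympy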
                    
                    # ===========================
                    # Configuration
                    # ===========================
                    MODEL_PATH = "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B"
                    MAX_TOKENS = 64000
                    RID = 0
                    QID = 0
                    PORT = 8000
                    DATASET_FILE = "aime25.jsonl"
                    
                    # Online algorithm parameters
                    WARMUP_TRACES = 16
                    TOTAL_BUDGET = 32
                    CONFIDENCE_PERCENTILE = 90
                    WINDOW_SIZE = 2048
                    
                    # ===========================
                    # Answer Extraction Functions
                    # ===========================
                    def extract_answer(text):
                        """Extract boxed answer from text"""
                        if "boxed" in text:
                            ans = text.split("boxed")[-1]
                            if len(ans) == 0:
                                return ""
                            elif ans[0] == "{":
                                stack = 1
                                a = ""
                                for c in ans[1:]:
                                    if c == "{":
                                        stack += 1
                                        a += c
                                    elif c == "}":
                                        stack -= 1
                                        if stack == 0:
                                            break
                                        a += c
                                    else:
                                        a += c
                            else:
                                a = ans.split("$")[0].strip()
                            return a.strip()
                        return None
                    
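                    def parse_func(s):
                        """Best-effort parse of an answer string into a SymPy expression; falls back to the raw string."""
                        for f in [parse_latex, parse_expr, latex2sympy]:
                            try:
                                return f(s.replace("\\\\", "\\"))
                            except Exception:
                                try:
                                    return f(s)
                                except Exception:
                                    pass
                        return s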
                    
                    
                    # ===========================
                    # Data Loading
                    # ===========================
                    def load_aime25_jsonl(file_path=DATASET_FILE):
                        """Load data from aime25.jsonl file"""
                        data = []
                        with open(file_path, 'r', encoding='utf-8') as file:
                            for line in file:
                                data.append(json.loads(line.strip()))
                        return data
                    
                    # ===========================
                    # Confidence Calculation
                    # ===========================
                    def compute_confidence(logprobs):
                        """Per-token confidence: negative mean top-k logprob at each position (higher = more confident)."""
                        confs = []
                        for lp in logprobs:
                            confs.append(round(-sum([l.logprob for l in lp]) / len(lp), 3))
                        return confs
                    
                    def compute_least_grouped(confs, group_size=WINDOW_SIZE):
                        """Mean confidence of every sliding window of `group_size` tokens (a single overall mean if the trace is shorter)."""
                        if len(confs) < group_size:
                            return [sum(confs) / len(confs)] if confs else [0]
                    
                        sliding_means = []
                        for i in range(len(confs) - group_size + 1):
                            window = confs[i:i + group_size]
                            sliding_means.append(round(sum(window) / len(window), 3))
                    
                        return sliding_means
                    
                    # ===========================
                    # Voting Functions
                    # ===========================
                    def weighted_majority_vote(answers, weights):
                        """Perform weighted majority voting"""
                        if not answers:
                            return None
                    
                        answer_weights = {}
                        for answer, weight in zip(answers, weights):
                            if answer is not None:
                                answer_str = str(answer)
                                answer_weights[answer_str] = answer_weights.get(answer_str, 0.0) + float(weight)
                    
                        if not answer_weights:
                            return None
                    
                        voted_answer = max(answer_weights.keys(), key=lambda x: answer_weights[x])
                        return voted_answer
                    
                    # ===========================
                    # Trace Processing and Accuracy Calculation
                    # ===========================
                    def process_trace(choice, trace_id, ground_truth):
                        """Process a single trace and calculate accuracy"""
                        # Extract basic info
                        text = choice.message.content
                        tokens = [t.token for t in choice.logprobs.content]
                    
                        # Calculate confidence
                        confs = compute_confidence([t.top_logprobs for t in choice.logprobs.content])
                        sliding_window = compute_least_grouped(confs, group_size=WINDOW_SIZE)
                    
                        # Extract and parse answer
                        extracted_answer = extract_answer(text)
                        parsed_answer = parse_func(extracted_answer) if extracted_answer else None
                    
                        # Calculate correctness
                        is_correct = False
                        if extracted_answer is not None and ground_truth is not None:
                            is_correct = math_equal(extracted_answer, ground_truth)
                    
                        trace_data = {
                            "trace_id": trace_id,
                            "stop_reason": choice.stop_reason,
                            "finish_reason": choice.finish_reason,
                            "text": text,
                            "tokens": tokens,
                            "token_count": len(tokens),
                            "confs": confs,
                            "group_confs": sliding_window,
                            "min_conf": min(sliding_window) if sliding_window else 0,
                            "extracted_answer": extracted_answer,
                            "parsed_answer": parsed_answer,
                            "is_correct": is_correct,
                        }
                    
                        return trace_data
                    
                    def calculate_statistics(traces, phase_name=""):
                        """Calculate statistics for a list of traces"""
                        if not traces:
                            return {}
                    
                        total_traces = len(traces)
                        correct_traces = sum(1 for t in traces if t['is_correct'])
                        total_tokens = sum(t['token_count'] for t in traces)
                    
                        # Confidence statistics
                        min_confs = [t['min_conf'] for t in traces]
                    
                        stats = {
                            f"{phase_name}_traces": total_traces,
                            f"{phase_name}_correct": correct_traces,
                            f"{phase_name}_accuracy": correct_traces / total_traces if total_traces > 0 else 0,
                            f"{phase_name}_total_tokens": total_tokens,
                            f"{phase_name}_avg_tokens_per_trace": total_tokens / total_traces if total_traces > 0 else 0,
                            f"{phase_name}_min_conf_mean": np.mean(min_confs) if min_confs else 0,
                            f"{phase_name}_min_conf_std": np.std(min_confs) if min_confs else 0,
                        }
                    
                        return stats
                    
                    # ===========================
                    # Problem Processing Function (like problemx in reference)
                    # ===========================
                    def process_problem_voting(test_json, ground_truth):
                        """
                        Process a problem result JSON and perform voting
                        Similar to problemx function in reference code
                        """
                        answers = []
                        bar = test_json['conf_bar']
                        weights = []
                        tokens = 0
                    
                        print(f"Warmup traces: {len(test_json['warmup_traces'])}, Final traces: {len(test_json['final_traces'])}")
                    
                        # Process warmup traces: drop those whose lowest group confidence is below the bar
                        for trace in test_json['warmup_traces']:
                            tokens += len(trace['tokens'])
                            if min(trace['group_confs']) < bar:
                                continue
                            answer = extract_answer(trace['text'])
                            if answer is not None:
                                answers.append(answer)
                                weights.append(1)  # Use weight 1 for consistency with reference
                    
                        # Process final traces: skip those stopped early by the confidence check ('gconf')
                        for trace in test_json['final_traces']:
                            tokens += len(trace['tokens'])
                            stop_reason = trace['stop_reason']
                            # stop_reason can also be an int (a stop token id), so check the type first
                            if isinstance(stop_reason, str) and 'gconf' in stop_reason:
                                continue
                            answer = extract_answer(trace['text'])
                            if answer is not None:
                                answers.append(answer)
                                weights.append(1)
                    
                        # Perform voting
                        voted = weighted_majority_vote(answers, weights)
                        is_correct = str(voted) == str(ground_truth) if voted is not None else False
                    
                        print(f'Bar: {bar:.4f}, Voted: {voted}, Ground truth: {ground_truth}, Correct: {is_correct}, Voting answers: {len(answers)}')
                    
                        return is_correct, len(answers), tokens
                    
                    # ===========================
                    # Main Function
                    # ===========================
                    def main():
                        print("="*60)
                        print("ONLINE ALGORITHM WITH ACCURACY CALCULATION")
                        print("="*60)
                        print(f"Model: {MODEL_PATH}")
                        print(f"Question ID: {QID}")
                        print(f"Run ID: {RID}")
                        print(f"Warmup traces: {WARMUP_TRACES}")
                        print(f"Total budget: {TOTAL_BUDGET}")
                        print(f"Confidence percentile: {CONFIDENCE_PERCENTILE}")
                    
                        # Load the data
                        data = load_aime25_jsonl()
                        print(f"Loaded {len(data)} items from {DATASET_FILE}")
                    
                        # Initialize client
                        client = openai.OpenAI(
                            api_key="None",
                            base_url=f"http://localhost:{PORT}/v1",
                            timeout=None
                        )
                    
                        # Get question and ground truth
                        prompt = data[QID]['question']
                        ground_truth = str(data[QID].get('answer', '')).strip()
                    
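                        messages = [
                            # DeepSeek-R1 system prompt, kept verbatim in Chinese. In English: "This assistant is
                            # DeepSeek-R1, created by DeepSeek. Today is Monday, May 28, 2025."
                            {"role": "system", "content": "该助手为DeepSeek-R1,由深度求索公司创造。\n今天是2025年5月28日,星期一。\n"},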
                            {"role": "user", "content": prompt}
                        ]
                    
                        print(f"\nQuestion: {prompt}")
                        print(f"Ground Truth: {ground_truth}")
                    
                        # ===========================
                        # WARMUP PHASE
                        # ===========================
                        print(f"\n{'-'*40}")
                        print("WARMUP PHASE")
                        print(f"{'-'*40}")
                    
                        t0 = time.time()
                        responses = client.chat.completions.create(
                            model=MODEL_PATH,
                            messages=messages,
                            max_tokens=MAX_TOKENS,
                            temperature=0.6,
                            top_p=0.95,
                            logprobs=True,
                            top_logprobs=20,
                            n=WARMUP_TRACES,
                            extra_body={"top_k": 0},
                        )
                        t1 = time.time()
                    
                        # Process warmup traces
                        warmup_traces = []
                        min_confs = []
                    
                        for j in range(WARMUP_TRACES):
                            trace_data = process_trace(responses.choices[j], j, ground_truth)
                            warmup_traces.append(trace_data)
                            min_confs.append(trace_data["min_conf"])
                    
                        # Calculate confidence bar
                        conf_bar = float(np.percentile(min_confs, CONFIDENCE_PERCENTILE))
                    
                        # Calculate warmup statistics
                        warmup_stats = calculate_statistics(warmup_traces, "warmup")
                    
                        print(f"Warmup time: {t1 - t0:.2f}s")
                        print(f"Confidence bar (P{CONFIDENCE_PERCENTILE}): {conf_bar:.4f}")
                        print(f"Warmup accuracy: {warmup_stats['warmup_accuracy']:.4f} ({warmup_stats['warmup_correct']}/{warmup_stats['warmup_traces']})")
                        print(f"Warmup total tokens: {warmup_stats['warmup_total_tokens']}")
                        print(f"Warmup avg tokens per trace: {warmup_stats['warmup_avg_tokens_per_trace']:.1f}")
                    
                        # Show some example results
                        print(f"\nFirst 3 warmup traces:")
                        for i, trace in enumerate(warmup_traces[:3]):
                            print(f"  Trace {i}: {trace['extracted_answer']} (correct: {trace['is_correct']}, conf: {trace['min_conf']:.4f}, tokens: {trace['token_count']})")
                    
                        # ===========================
                        # FINAL PHASE
                        # ===========================
                        print(f"\n{'-'*40}")
                        print("FINAL PHASE (with early stopping)")
                        print(f"{'-'*40}")
                    
                        real_gen = TOTAL_BUDGET - WARMUP_TRACES
                    
                        t3 = time.time()
                        responses = client.chat.completions.create(
                            model=MODEL_PATH,
                            messages=messages,
                            max_tokens=MAX_TOKENS,
                            temperature=0.6,
                            top_p=0.95,
                            logprobs=True,
                            top_logprobs=20,
                            n=real_gen,
                            extra_body={
                                "top_k": 0,
                                "vllm_xargs": {
                                    'enable_conf': True,
                                    'window_size': WINDOW_SIZE,
                                    'threshold': conf_bar
                                }
                            }
                        )
                        t4 = time.time()
                    
                        # Process final traces
                        final_traces = []
                        for j in range(len(responses.choices)):
                            trace_data = process_trace(responses.choices[j], WARMUP_TRACES + j, ground_truth)
                            final_traces.append(trace_data)
                    
                        # Calculate final statistics
                        final_stats = calculate_statistics(final_traces, "final")
                    
                        print(f"Final time: {t4 - t3:.2f}s")
                        print(f"Final traces generated: {len(final_traces)} (requested: {real_gen})")
                        print(f"Final accuracy: {final_stats['final_accuracy']:.4f} ({final_stats['final_correct']}/{final_stats['final_traces']})")
                        print(f"Final total tokens: {final_stats['final_total_tokens']}")
                        print(f"Final avg tokens per trace: {final_stats['final_avg_tokens_per_trace']:.1f}")
                    
                        # Show some example results
                        print(f"\nFirst 3 final traces:")
                        for i, trace in enumerate(final_traces[:3]):
                            print(f"  Trace {i}: {trace['extracted_answer']} (correct: {trace['is_correct']}, conf: {trace['min_conf']:.4f}, tokens: {trace['token_count']})")
                    
                        # ===========================
                        # OVERALL STATISTICS
                        # ===========================
                        print(f"\n{'-'*40}")
                        print("OVERALL STATISTICS")
                        print(f"{'-'*40}")
                    
                        all_traces = warmup_traces + final_traces
                        overall_stats = calculate_statistics(all_traces, "overall")
                    
                        total_time = t4 - t0
                        warmup_time = t1 - t0
                        final_time = t4 - t3
                    
                        print(f"Total time: {total_time:.2f}s (warmup: {warmup_time:.2f}s, final: {final_time:.2f}s)")
                        print(f"Overall traces: {len(all_traces)}")
                        print(f"Overall accuracy: {overall_stats['overall_accuracy']:.4f} ({overall_stats['overall_correct']}/{overall_stats['overall_traces']})")
                        print(f"Overall total tokens: {overall_stats['overall_total_tokens']}")
                        print(f"Overall avg tokens per trace: {overall_stats['overall_avg_tokens_per_trace']:.1f}")
                    
                        # Efficiency metrics
                        tokens_per_second = overall_stats['overall_total_tokens'] / total_time
                        traces_per_second = len(all_traces) / total_time
                    
                        print(f"Tokens per second: {tokens_per_second:.1f}")
                        print(f"Traces per second: {traces_per_second:.2f}")
                    
                        # Early stopping analysis
                        if final_traces:
                            above_threshold = sum(1 for t in final_traces if t['min_conf'] >= conf_bar)
                            print(f"Final traces above threshold: {above_threshold}/{len(final_traces)} ({above_threshold/len(final_traces):.2%})")
                    
                        # ===========================
                        # SAVE RESULTS
                        # ===========================
                        results = {
                            "question_id": QID,
                            "run_id": RID,
                            "question": prompt,
                            "ground_truth": ground_truth,
                            "conf_bar": conf_bar,
                            "warmup_traces": warmup_traces,
                            "final_traces": final_traces,
                            "statistics": {
                                **warmup_stats,
                                **final_stats,
                                **overall_stats,
                                "total_time": total_time,
                                "warmup_time": warmup_time,
                                "final_time": final_time,
                                "tokens_per_second": tokens_per_second,
                                "traces_per_second": traces_per_second,
                            },
                            "config": {
                                "model_path": MODEL_PATH,
                                "warmup_traces": WARMUP_TRACES,
                                "total_budget": TOTAL_BUDGET,
                                "confidence_percentile": CONFIDENCE_PERCENTILE,
                                "window_size": WINDOW_SIZE,
                            },
                            "timestamp": datetime.now().isoformat(),
                        }
                    
                        # Save results to a pickle file
                        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                        os.makedirs("outputs", exist_ok=True)
                        output_filename = f"outputs/q{QID}_r{RID}_{timestamp}.pkl"
                        with open(output_filename, 'wb') as f:
                            pickle.dump(results, f)
                    
                        print(f"\n✅ Results saved to {output_filename}")
                    
                        # ===========================
                        # COMPARISON WITH BASELINE
                        # ===========================
                        print(f"\n{'-'*40}")
                        print("COMPARISON ANALYSIS")
                        print(f"{'-'*40}")
                    
                        # Compare warmup vs final phase
                        if warmup_stats['warmup_traces'] > 0 and final_stats['final_traces'] > 0:
                            acc_improvement = final_stats['final_accuracy'] - warmup_stats['warmup_accuracy']
                            print(f"Accuracy change (final - warmup): {acc_improvement:+.4f}")
                    
                            token_efficiency = final_stats['final_avg_tokens_per_trace'] / warmup_stats['warmup_avg_tokens_per_trace']
                            print(f"Token efficiency (final/warmup): {token_efficiency:.2f}x")
                    
                        # Early stopping effectiveness
                        total_requested = WARMUP_TRACES + (TOTAL_BUDGET - WARMUP_TRACES)
                        total_generated = len(all_traces)
                        saving_ratio = (total_requested - total_generated) / total_requested
                    
                        print(f"Traces requested: {total_requested}")
                        print(f"Traces generated: {total_generated}")
                        print(f"Early stopping saving: {saving_ratio:.2%}")
                    
                        # ===========================
                        # VOTING FOR FINAL ANSWER
                        # ===========================
                        print(f"\n{'-'*40}")
                        print("VOTING FOR FINAL ANSWER")
                        print(f"{'-'*40}")
                    
                        # Collect answers above threshold for voting
                        voting_answers = []
                        voting_weights = []
                        total_voting_tokens = 0
                    
                        # Add warmup traces above threshold (use confidence as weight)
                        for trace in warmup_traces:
                            minx = trace['min_conf']
                            total_voting_tokens += trace['token_count']
                            if minx >= conf_bar:
                                answer = trace['extracted_answer']
                                if answer is not None:
                                    voting_answers.append(answer)
                                    voting_weights.append(minx)
                    
                        # Add final traces above threshold (use weight 1)
                        for trace in final_traces:
                            total_voting_tokens += trace['token_count']
                            # Skip traces that were stopped by gconf (early stopping)
                            if trace['stop_reason'] is not None and 'gconf' in trace['stop_reason']:
                                continue
                    
                            minx = trace['min_conf']
                            # Note: final traces might not need threshold check since they were already filtered
                            # But keeping it consistent with the reference code
                            if minx >= conf_bar:
                                answer = trace['extracted_answer']
                                if answer is not None:
                                    voting_answers.append(answer)
                                    voting_weights.append(1)
                    
                        print(f"Traces used for voting: {len(voting_answers)}")
                        print(f"Total tokens in voting traces: {total_voting_tokens}")
                    
                        # Perform weighted majority vote
                        voted_answer = weighted_majority_vote(voting_answers, voting_weights)
                    
                        # Check if voted answer is correct
                        is_voted_correct = False
                        if voted_answer is not None and ground_truth:
                            is_voted_correct = str(voted_answer) == str(ground_truth)
                            # Also try math_equal for more robust comparison
                            try:
                                is_voted_correct_math = math_equal(voted_answer, ground_truth)
                                if is_voted_correct != is_voted_correct_math:
                                    print(f"Warning: String comparison ({is_voted_correct}) != math comparison ({is_voted_correct_math})")
                                    is_voted_correct = is_voted_correct_math  # Use math_equal as more reliable
                            except Exception:
                                pass  # math_equal failed; keep the string-comparison result
                    
                        print(f"Voted answer: {voted_answer}")
                        print(f"Ground truth: {ground_truth}")
                        print(f"Voted answer correct: {is_voted_correct}")
                    
                        # Show voting breakdown
                        if voting_answers:
                            answer_counts = Counter(voting_answers)
                            print(f"\nVoting breakdown:")
                            for answer, count in answer_counts.most_common():
                                print(f"  {answer}: {count} votes")
                    
                        # ===========================
                        # UPDATE RESULTS WITH VOTING
                        # ===========================
                        voting_results = {
                            "voting_answers": voting_answers,
                            "voting_weights": voting_weights,
                            "voted_answer": voted_answer,
                            "is_voted_correct": is_voted_correct,
                            "voting_traces_count": len(voting_answers),
                            "voting_total_tokens": total_voting_tokens,
                        }
                    
                        results["voting"] = voting_results
                        results["statistics"]["voting_traces_count"] = len(voting_answers)
                        results["statistics"]["voting_total_tokens"] = total_voting_tokens
                        results["statistics"]["is_voted_correct"] = is_voted_correct
                    
                        # Update the saved file with voting results (pickle requires binary mode)
                        with open(output_filename, 'wb') as f:
                            pickle.dump(results, f)
                    
                        print(f"\n✅ Results updated with voting information: {output_filename}")
                    
                        # ===========================
                        # FINAL SUMMARY
                        # ===========================
                        print(f"\n{'='*60}")
                        print("FINAL SUMMARY")
                        print(f"{'='*60}")
                        print(f"Question ID: {QID}")
                        print(f"Confidence bar: {conf_bar:.4f}")
                        print(f"Total traces generated: {len(all_traces)}")
                        print(f"Traces used for voting: {len(voting_answers)}")
                        print(f"Total tokens: {overall_stats['overall_total_tokens']}")
                        print(f"Voting answer: {voted_answer}")
                        print(f"Ground truth: {ground_truth}")
                        print(f"Final result: {'✅ CORRECT' if is_voted_correct else '❌ INCORRECT'}")
                    
                        return results
                    
                    if __name__ == "__main__":
                        results = main()
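The voting step above delegates to a `weighted_majority_vote` helper that is defined earlier in the example. For readers who want to see the idea in isolation, here is a minimal sketch of a confidence-weighted majority vote. `weighted_majority_vote_sketch` is a hypothetical name, and the authors' helper may handle ties or `None` answers differently.

```python
from collections import defaultdict
from collections.abc import Hashable, Sequence
from typing import Optional


def weighted_majority_vote_sketch(
    answers: Sequence[Hashable], weights: Sequence[float]
) -> Optional[Hashable]:
    """Return the answer whose weights sum highest, or None if there are no votes.

    Hypothetical stand-in for the example's weighted_majority_vote helper.
    Ties go to the answer that was seen first.
    """
    if len(answers) != len(weights):
        raise ValueError("answers and weights must have the same length")
    totals = defaultdict(float)  # answer -> summed weight (insertion-ordered)
    for answer, weight in zip(answers, weights):
        totals[answer] += weight
    if not totals:
        return None
    # max() keeps the first maximal key, so ties resolve by first appearance.
    return max(totals, key=totals.__getitem__)


if __name__ == "__main__":
    answers = ["42", "17", "42", "17", "17"]
    weights = [0.9, 0.5, 0.8, 0.4, 0.3]  # e.g. min_conf of warmup traces
    # "42" totals 1.7 and "17" totals 1.2, so "42" wins despite fewer votes.
    print(weighted_majority_vote_sketch(answers, weights))  # 42
```

An unweighted majority would pick "17" here (3 votes to 2); the example shows how using `min_conf` as the weight for warmup traces can change the outcome, while final traces (weight 1) count as plain votes.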
                

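Each run is pickled to `outputs/q{QID}_r{RID}_{timestamp}.pkl` as the `results` dict built in the script. The following is a minimal sketch for aggregating several runs, assuming that layout; `summarize_saved_runs` is a hypothetical helper, and it reads only keys written by the first save (`final_accuracy` and `overall_total_tokens` under `statistics`). Only unpickle files you produced yourself, since `pickle.load` can execute arbitrary code.

```python
import glob
import os
import pickle


def summarize_saved_runs(output_dir: str = "outputs") -> dict:
    """Aggregate statistics across saved q*_r*_*.pkl result files.

    Assumes each pickle holds the `results` dict from the example script.
    """
    paths = sorted(glob.glob(os.path.join(output_dir, "q*_r*_*.pkl")))
    runs = []
    for path in paths:
        with open(path, "rb") as f:
            runs.append(pickle.load(f))
    if not runs:
        return {"num_runs": 0, "mean_final_accuracy": None, "total_tokens": 0}
    final_accs = [run["statistics"]["final_accuracy"] for run in runs]
    total_tokens = sum(run["statistics"]["overall_total_tokens"] for run in runs)
    return {
        "num_runs": len(runs),
        "mean_final_accuracy": sum(final_accs) / len(final_accs),
        "total_tokens": total_tokens,
    }


if __name__ == "__main__":
    print(summarize_saved_runs("outputs"))
```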
⚠️ Important Notes

  • Make sure the vLLM server is running before executing the examples
  • Adjust the configuration parameters (e.g., MODEL_PATH, WARMUP_TRACES, TOTAL_BUDGET, CONFIDENCE_PERCENTILE, WINDOW_SIZE) to match your model and compute budget
  • The examples use the Dynasor library for answer evaluation; make sure it is installed
  • For production use, consider adding proper error handling and logging
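For the first note, a small readiness check can poll the server before any requests are sent. This is a minimal sketch that assumes the server listens on the default `http://localhost:8000` and exposes vLLM's `GET /health` route; `wait_for_vllm_server` is a hypothetical helper, not part of vLLM.

```python
import time
import urllib.error
import urllib.request


def wait_for_vllm_server(
    base_url: str = "http://localhost:8000", timeout_s: float = 60.0, poll_s: float = 1.0
) -> bool:
    """Poll {base_url}/health until it returns HTTP 200; give up after timeout_s."""
    url = base_url.rstrip("/") + "/health"
    deadline = time.monotonic() + timeout_s
    while True:
        try:
            with urllib.request.urlopen(url, timeout=poll_s) as resp:
                if resp.status == 200:
                    return True
        except (urllib.error.URLError, OSError):
            pass  # not up yet: connection refused, timeout, or non-2xx status
        if time.monotonic() >= deadline:
            return False
        time.sleep(poll_s)


if __name__ == "__main__":
    if not wait_for_vllm_server():
        raise SystemExit("vLLM server did not become ready in time")
```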

Summary

This guide has provided you with:

  1. ✅ Complete instructions for modifying vLLM to implement DeepConf
  2. ✅ An alternative: directly downloading our pre-implemented PR
  3. ✅ Three comprehensive examples for different use cases:
    • Offline Generation with parallel processing
    • Offline Voting with multiple strategies
    • Online Generation with confidence-based early stopping

For more information and updates, visit the official PR on GitHub.