#!/usr/bin/env python3
"""
MTP Benchmark Runner — Automated Multi-Token Prediction benchmarking.

Auto-discovers MTP GGUF models, starts llama-server in podman containers
(with and without MTP), runs mtp-bench.py against each configuration, and
collects structured JSON results.

Usage:
    python run_mtp_bench.py                          # run everything
    python run_mtp_bench.py --model "Qwen3.6-35B"    # filter by model name
    python run_mtp_bench.py --toolbox vulkan-radv    # filter by toolbox
    python run_mtp_bench.py --models-dir /path/to/models
    python run_mtp_bench.py --port 8081
"""

import argparse
import datetime
import json
import os
import platform
import re
import signal
import subprocess
import sys
import time
from pathlib import Path
from urllib import request
from urllib.error import URLError

# ── Toolbox definitions ──────────────────────────────────────────────────────

# Container image plus extra `podman run` arguments (device passthrough etc.)
# for each llama.cpp build that can be benchmarked.
TOOLBOXES = {
    "rocm-7.2.3": {
        "image": "docker.io/kyuz0/amd-strix-halo-toolboxes:rocm-7.2.3",
        "engine_args": [
            "--device", "/dev/dri",
            "--device", "/dev/kfd",
            "--group-add", "video",
            "--group-add", "render",
            "--security-opt", "seccomp=unconfined",
        ],
    },
    "vulkan-radv": {
        "image": "docker.io/kyuz0/amd-strix-halo-toolboxes:vulkan-radv",
        "engine_args": [
            "--device", "/dev/dri",
            "--group-add", "video",
            "--security-opt", "seccomp=unconfined",
        ],
    },
}

# ── Benchmark modes ──────────────────────────────────────────────────────────

# Extra llama-server flags per mode; "baseline" adds no speculative-decoding flags.
MODES = {
    "baseline": [],
    "mtp-2": ["--spec-type", "draft-mtp", "--spec-draft-n-max", "2", "-np", "1"],
    "mtp-3": ["--spec-type", "draft-mtp", "--spec-draft-n-max", "3", "-np", "1"],
}

# ── Constants ────────────────────────────────────────────────────────────────

CONTAINER_NAME = "mtp-bench-server"
HEALTH_TIMEOUT = 180  # seconds to wait for server readiness
HEALTH_INTERVAL = 3   # seconds between health polls
COOLDOWN = 5          # seconds between runs
BENCH_SCRIPT = Path(__file__).parent / "mtp-bench.py"

# Multi-shard GGUF suffix at the end of a file stem, e.g. "-00001-of-00003".
_SHARD_RE = re.compile(r"-(\d{5})-of-(\d{5})$")


def _sigint_handler(sig, frame):
    """Clean up container on Ctrl+C."""
    print("\n\n🛑 Interrupted — cleaning up container...")
    subprocess.run(["podman", "stop", "-t", "3", CONTAINER_NAME], capture_output=True)
    subprocess.run(["podman", "rm", "-f", CONTAINER_NAME], capture_output=True)
    sys.exit(1)


# Registered at import time so Ctrl+C stops and removes the benchmark
# container before the process exits.
signal.signal(signal.SIGINT, _sigint_handler)


# ── Model discovery ──────────────────────────────────────────────────────────


def discover_models(models_dir: Path) -> list[dict]:
    """Recursively scan models_dir for GGUF files with 'MTP' in their path.

    Returns one dict per model, in sorted path order:
        {"name": <file stem with any shard suffix stripped>,
         "gguf": <path relative to models_dir>}

    Multi-shard models are listed once, via their first shard
    ("-00001-of-NNNNN"). Files named "mmproj*" (projector files that commonly
    sit next to a model) are skipped. Exits with status 1 if models_dir is
    not a directory.
    """
    if not models_dir.is_dir():
        print(f"❌ Models directory not found: {models_dir}")
        sys.exit(1)

    models = []
    for gguf in sorted(models_dir.rglob("*.gguf")):
        rel = gguf.relative_to(models_dir)
        # Must have MTP somewhere in the path (case-insensitive)
        if "mtp" not in str(rel).lower():
            continue
        if gguf.name.lower().startswith("mmproj"):
            continue

        # Use the real filename (without .gguf) so the quant level is kept.
        display_name = gguf.stem
        shard = _SHARD_RE.search(display_name)
        if shard:
            if int(shard.group(1)) != 1:
                continue  # non-first shard of a multi-shard model
            display_name = display_name[: shard.start()]

        models.append({"name": display_name, "gguf": str(rel)})
    return models


# ── System info ──────────────────────────────────────────────────────────────


def capture_system_info(results_dir: Path):
    """Write results_dir/system_info.json once; no-op if it already exists.

    Records the distro name, kernel release, linux-firmware package version
    and a local timestamp. Each probe is best-effort and falls back to a
    placeholder value instead of raising.
    """
    path = results_dir / "system_info.json"
    if path.exists():
        return

    def get_distro():
        try:
            with open("/etc/os-release", encoding="utf-8") as f:
                for line in f:
                    if line.startswith("PRETTY_NAME="):
                        return line.split("=", 1)[1].strip().strip('"')
        except (OSError, ValueError):
            pass  # missing or unreadable os-release: use the generic fallback
        return "Linux"

    def get_linux_firmware():
        try:
            r = subprocess.run(["rpm", "-q", "linux-firmware"], capture_output=True, text=True)
        except (OSError, ValueError):
            return "unknown"  # e.g. rpm is not installed
        if r.returncode == 0:
            return r.stdout.strip()
        return "unknown"

    info = {
        "distro": get_distro(),
        "kernel": platform.release(),
        "linux_firmware": get_linux_firmware(),
        "timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    }
    path.write_text(json.dumps(info, indent=2) + "\n")
    print(f"📋 Captured system info → {path}")
# ── Cleanup & container lifecycle ────────────────────────────────────────────


def check_port_free(port: int) -> bool:
    """Return True if 127.0.0.1:<port> can be bound right now, else False.

    This only probes the port; it does not identify or report the holder
    (kill_port_holder does that).
    """
    import socket

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        # SO_REUSEADDR keeps lingering TIME_WAIT connections from reading as
        # "busy"; on Linux an active listener on the port still fails bind().
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        try:
            s.bind(("127.0.0.1", port))
        except OSError:
            return False
    return True


def kill_port_holder(port: int):
    """Find and SIGKILL whatever process is holding ``port`` (best effort).

    PIDs come from ``fuser``. If it finds none, ``ss`` is consulted only to
    report the listener. A missing ``fuser``/``ss`` binary is reported
    rather than raised.

    NOTE(review): ``kill -9`` hits any process on the port, not only podman's
    proxy; call this deliberately.
    """
    try:
        fuser = subprocess.run(["fuser", f"{port}/tcp"], capture_output=True, text=True)
        pids = fuser.stdout.strip().split()
    except FileNotFoundError:
        pids = []

    if pids:
        print(f" ⚠️ Port {port} held by PID(s): {', '.join(pids)} — killing...")
        for pid in pids:
            pid = pid.strip()
            if pid.isdigit():
                subprocess.run(["kill", "-9", pid], capture_output=True)
        time.sleep(2)
        return

    # fuser didn't find it (or isn't installed), try ss as fallback
    try:
        ss = subprocess.run(["ss", "-tlnp", f"sport = :{port}"], capture_output=True, text=True)
    except FileNotFoundError:
        print(f" ⚠️ Could not inspect port {port}: neither fuser nor ss is installed")
        return
    report = ss.stdout.strip()
    # ss prints a column-header line even when nothing matches; only report
    # when there is at least one socket row under it.
    if len(report.splitlines()) > 1:
        print(f" ⚠️ Port {port} is in use (could not identify PID):")
        print(f" {report}")


def cleanup(port: int) -> bool:
    """Stop any stale benchmark container, then wait up to 15 s for ``port`` to free.

    Returns True as soon as the port is bindable, False if it is still busy
    after the wait.
    """
    stop_container()
    # Podman's port proxy can keep the port for a moment after the container stops.
    for attempt in range(15):  # up to 15 seconds
        if check_port_free(port):
            return True
        if attempt == 0:
            print(f" ⏳ Waiting for port {port} to be released...")
        time.sleep(1)
    print(f" ❌ Port {port} still in use after 15s. Something else is holding it.")
    return False


def stop_container():
    """Stop and remove the benchmark container; harmless if it does not exist."""
    subprocess.run(["podman", "stop", "-t", "5", CONTAINER_NAME], capture_output=True)
    subprocess.run(["podman", "rm", "-f", CONTAINER_NAME], capture_output=True)


def start_server(toolbox: dict, gguf: str, models_dir: Path,
                 spec_flags: list[str], port: int, ctx_size: int = 12288) -> bool:
    """Start llama-server detached in a podman container.

    Args:
        toolbox: a TOOLBOXES entry (uses its "image" and "engine_args").
        gguf: model path relative to models_dir, which is mounted read-only
            at /models.
        models_dir: host directory containing the models.
        spec_flags: extra llama-server flags for the benchmark mode.
        port: port llama-server listens on; published on 127.0.0.1 only.
        ctx_size: context size passed as ``-c`` (default 12288, the value
            previously hard-coded).

    Returns:
        True if ``podman run`` succeeded. The server may still be loading, so
        use wait_for_health afterwards. False if the port could not be freed
        or the container failed to start.
    """
    import shlex

    if not cleanup(port):
        return False

    cmd = [
        "podman", "run", "--rm", "--replace", "-d",
        "--name", CONTAINER_NAME,
        "--security-opt", "label=disable",
        "--userns=keep-id",
        *toolbox["engine_args"],
        "-v", f"{models_dir}:/models:ro",
        "-p", f"127.0.0.1:{port}:{port}",
        toolbox["image"],
        "llama-server",
        "-m", f"/models/{gguf}",
        "-c", str(ctx_size),
        "-ngl", "999",
        "--host", "0.0.0.0",
        "--port", str(port),
        "--no-mmap",
        "-fa", "1",
        "--jinja",
        *spec_flags,
    ]
    print(" 🐳 Starting container...")
    # shlex.join quotes each argument, so the echoed command can be pasted into a shell.
    print(f" {shlex.join(cmd)}")
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        print(f" ❌ Failed to start container: {result.stderr.strip()}")
        return False
    return True


def is_container_alive() -> bool:
    """Return True if podman reports the benchmark container's status as "running"."""
    result = subprocess.run(
        ["podman", "inspect", "--format", "{{.State.Status}}", CONTAINER_NAME],
        capture_output=True, text=True,
    )
    return result.returncode == 0 and result.stdout.strip() == "running"


def get_last_log_line() -> str:
    """Return the last non-empty line (max 120 chars) of the container's recent logs, or ""."""
    result = subprocess.run(
        ["podman", "logs", "--tail", "3", CONTAINER_NAME],
        capture_output=True, text=True,
    )
    # Merge stdout+stderr (llama-server logs to stderr)
    lines = (result.stdout + result.stderr).strip().split("\n")
    for line in reversed(lines):
        line = line.strip()
        if line:
            return line[:120]  # truncate for display
    return ""


def _container_logs(tail: int) -> str:
    """Return the last ``tail`` log lines of the container (stdout then stderr), stripped.

    If the container no longer exists, this is podman's error text instead.
    """
    logs = subprocess.run(
        ["podman", "logs", "--tail", str(tail), CONTAINER_NAME],
        capture_output=True, text=True,
    )
    return (logs.stdout + logs.stderr).strip()


def wait_for_health(port: int) -> bool:
    """Poll the server's /health endpoint until it reports ready.

    Returns True once GET /health answers 200 with a JSON object whose
    "status" is "ok". Returns False, after printing container logs, if the
    container exits or HEALTH_TIMEOUT seconds elapse. Every 15 s the latest
    container log line is printed as progress.
    """
    import http.client

    url = f"http://127.0.0.1:{port}/health"
    started = time.monotonic()  # monotonic clock: unaffected by wall-clock changes
    deadline = started + HEALTH_TIMEOUT
    status_interval = 15  # show log progress every 15s
    last_status_time = float("-inf")  # so progress is shown on the first poll

    print(f" ⏳ Waiting for server health ({HEALTH_TIMEOUT}s timeout)...")
    # Brief initial wait for container to start
    time.sleep(3)

    while time.monotonic() < deadline:
        # Check if container died (or was auto-removed by --rm)
        if not is_container_alive():
            print(" ❌ Container exited — model likely failed to load.")
            # Logs are unavailable if --rm already removed the container
            output = _container_logs(20)
            if output and "no container with name" not in output:
                print(" 📝 Container logs:")
                for line in output.split("\n")[-10:]:
                    print(f" {line}")
            else:
                print(" 💡 Container was auto-removed. To see the error, run manually:")
                print(" podman run --name debug-server ... (without -d)")
            return False

        try:
            with request.urlopen(request.Request(url, method="GET"), timeout=5) as r:
                if r.status == 200:
                    data = json.loads(r.read())
                    if isinstance(data, dict) and data.get("status") == "ok":
                        print(" ✅ Server ready!")
                        return True
        except (URLError, OSError, ValueError, http.client.HTTPException):
            # Not ready yet: connection refused, HTTP error status, truncated
            # or malformed response, or a body that is not valid JSON.
            pass

        # Periodically show what the server is doing
        now = time.monotonic()
        if now - last_status_time >= status_interval:
            last_status_time = now
            elapsed = int(now - started)
            log_line = get_last_log_line()
            if log_line:
                print(f" [{elapsed}s] {log_line}")
            else:
                print(f" [{elapsed}s] (waiting...)")
        time.sleep(HEALTH_INTERVAL)

    print(f" ❌ Health check timeout after {HEALTH_TIMEOUT}s")
    # Dump container logs for debugging
    output = _container_logs(30)
    if output:
        print(" 📝 Last lines of server logs:")
        for line in output.split("\n")[-15:]:
            print(f" {line}")
    return False
Benchmark execution ────────────────────────────────────────────────────── def run_benchmark(port: int, out_path: Path) -> dict | None: """Run mtp-bench.py as subprocess. Returns parsed JSON or None.""" cmd = [ sys.executable, str(BENCH_SCRIPT), "--url", f"http://127.0.0.1:{port}", "--out", str(out_path), ] print(f" 🔬 Running benchmark...") result = subprocess.run(cmd, capture_output=False) if result.returncode != 0: print(f" ❌ Benchmark failed (exit {result.returncode})") return None try: return json.loads(out_path.read_text()) except Exception as e: print(f" ❌ Failed to read results: {e}") return None # ── Results handling ───────────────────────────────────────────────────────── def save_result(out_path: Path, model: dict, toolbox_name: str, mode: str, spec_flags: list[str], bench_data: dict): """Wrap mtp-bench output with run metadata and save.""" wrapped = { "model": model["name"], "gguf": model["gguf"], "toolbox": toolbox_name, "mode": mode, "spec_flags": " ".join(spec_flags) if spec_flags else "(none)", "timestamp": datetime.datetime.now().isoformat(), **bench_data, } out_path.write_text(json.dumps(wrapped, indent=2) + "\n") def print_summary(results_dir: Path): """Read all result JSONs and print a summary table.""" results = [] for f in sorted(results_dir.glob("*.json")): if f.name == "system_info.json" or f.name == "summary.json": continue try: data = json.loads(f.read_text()) if "aggregate" in data: results.append(data) except Exception: continue if not results: print("\n📊 No results found.") return # Compute average tok/s per result for r in results: prompts = r.get("results", []) if prompts: r["_avg_toks"] = sum(p.get("predicted_per_second", 0) for p in prompts) / len(prompts) else: r["_avg_toks"] = 0 # Build baseline lookup for speedup calculation baselines = {} for r in results: if r["mode"] == "baseline": key = (r["model"], r["toolbox"]) baselines[key] = r["_avg_toks"] # Print table print("\n" + "=" * 115) print(f"{'Model':<30} {'Toolbox':<20} 
{'Mode':<10} {'Prefill pt/s':>13} {'Avg tok/s':>10} {'Accept%':>9} {'Wall(s)':>8} {'Speedup':>8}") print("-" * 115) for r in results: agg = r.get("aggregate", {}) accept = agg.get("aggregate_accept_rate") wall = agg.get("wall_s_total", 0) accept_str = f"{accept * 100:.1f}%" if accept is not None else "—" avg_toks = r["_avg_toks"] avg_prompt = agg.get("avg_prompt_per_second") prefill_str = f"{avg_prompt:.1f}" if avg_prompt is not None else "—" # Speedup relative to baseline baseline_key = (r["model"], r["toolbox"]) baseline_toks = baselines.get(baseline_key) if baseline_toks and baseline_toks > 0: speedup = f"{avg_toks / baseline_toks:.2f}×" else: speedup = "—" print(f"{r['model']:<30} {r['toolbox']:<20} {r['mode']:<10} {prefill_str:>13} {avg_toks:>10.1f} {accept_str:>9} {wall:>8.1f} {speedup:>8}") print("=" * 115) # Write summary.json summary_data = [] for r in results: agg = r.get("aggregate", {}) summary_data.append({ "model": r["model"], "toolbox": r["toolbox"], "mode": r["mode"], "avg_prompt_tok_s": agg.get("avg_prompt_per_second"), "avg_tok_s": round(r["_avg_toks"], 1), "accept_rate": agg.get("aggregate_accept_rate"), "wall_s_total": agg.get("wall_s_total"), "results": r.get("results", []) }) summary_path = results_dir / "summary.json" summary_path.write_text(json.dumps(summary_data, indent=2) + "\n") print(f"\n📄 Summary written to {summary_path}") # ── Main ───────────────────────────────────────────────────────────────────── def main(): ap = argparse.ArgumentParser(description="MTP Benchmark Runner") ap.add_argument("--models-dir", type=Path, default=Path.home() / "models", help="Directory containing GGUF models (default: ~/models)") ap.add_argument("--model", type=str, default=None, help="Filter: only run models whose name contains this string") ap.add_argument("--toolbox", type=str, default=None, help="Filter: only run this toolbox (e.g. 
'vulkan-radv-mtp')") ap.add_argument("--port", type=int, default=8080, help="Port for llama-server (default: 8080)") args = ap.parse_args() models_dir = args.models_dir.expanduser().resolve() results_dir = Path(__file__).parent / "results-mtp" results_dir.mkdir(exist_ok=True) # Discover models models = discover_models(models_dir) if args.model: models = [m for m in models if args.model.lower() in m["name"].lower()] if not models: print(f"❌ No MTP models found in {models_dir}") sys.exit(1) # Filter toolboxes toolboxes = TOOLBOXES if args.toolbox: if args.toolbox not in TOOLBOXES: print(f"❌ Unknown toolbox: {args.toolbox}") print(f" Available: {', '.join(TOOLBOXES.keys())}") sys.exit(1) toolboxes = {args.toolbox: TOOLBOXES[args.toolbox]} # Print run plan print(f"\n🔍 Discovered {len(models)} MTP model(s):") for m in models: print(f" • {m['name']} → {m['gguf']}") print(f"\n🧰 Toolboxes: {', '.join(toolboxes.keys())}") print(f"📊 Modes: {', '.join(MODES.keys())}") total = len(models) * len(toolboxes) * len(MODES) print(f"📋 Total runs: {total}\n") # Capture system info capture_system_info(results_dir) # Pre-flight cleanup print("🧹 Pre-flight cleanup...") if not cleanup(args.port): print("❌ Cannot free port — exiting.") sys.exit(1) print("✅ Environment clean.\n") # Run benchmarks run_count = 0 for tb_name, tb_config in toolboxes.items(): for model in models: for mode_name, spec_flags in MODES.items(): run_count += 1 out_file = results_dir / f"{model['name']}__{tb_name}__{mode_name}.json" print(f"\n{'─' * 80}") print(f"▶ [{run_count}/{total}] {model['name']} | {tb_name} | {mode_name}") print(f" Output: {out_file.name}") # Skip if results exist if out_file.exists(): print(f" ⏩ Skipping — results already exist") continue # Start server if not start_server(tb_config, model["gguf"], models_dir, spec_flags, args.port): stop_container() continue # Wait for health if not wait_for_health(args.port): stop_container() continue # Run benchmark bench_data = run_benchmark(args.port, 
out_file) # Stop server print(f" 🛑 Stopping container...") stop_container() # Wrap result with metadata if bench_data: save_result(out_file, model, tb_name, mode_name, spec_flags, bench_data) print(f" ✅ Done — saved to {out_file.name}") else: print(f" ❌ No results collected") # Cooldown if run_count < total: print(f" 💤 Cooldown ({COOLDOWN}s)...") time.sleep(COOLDOWN) # Print summary print_summary(results_dir) if __name__ == "__main__": main()