feat: add automated MTP benchmark runner for llama-server via podman containers
This commit is contained in:
Executable
+431
@@ -0,0 +1,431 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
MTP Benchmark Runner — Automated Multi-Token Prediction benchmarking.
|
||||
|
||||
Auto-discovers MTP GGUF models, starts llama-server in podman containers
|
||||
(with and without MTP), runs mtp-bench.py against each configuration,
|
||||
and collects structured JSON results.
|
||||
|
||||
Usage:
|
||||
python run_mtp_bench.py # run everything
|
||||
python run_mtp_bench.py --model "Qwen3.6-35B" # filter by model name
|
||||
python run_mtp_bench.py --toolbox vulkan-radv-mtp # filter by toolbox
|
||||
python run_mtp_bench.py --models-dir /path/to/models
|
||||
python run_mtp_bench.py --port 8081
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from urllib import request
|
||||
from urllib.error import URLError
|
||||
|
||||
# ── Toolbox definitions ──────────────────────────────────────────────────────
|
||||
|
||||
# Maps a toolbox name to the container image to run and the extra
# `podman run` arguments that are inserted before the image name.
TOOLBOXES = {
    "rocm-7.2.3-mtp": {
        "image": "docker.io/kyuz0/amd-strix-halo-toolboxes:rocm-7.2.3-mtp",
        "engine_args": [
            "--device", "/dev/dri",
            "--device", "/dev/kfd",
            "--group-add", "video",
            "--group-add", "render",
            "--security-opt", "seccomp=unconfined",
        ],
    },
    "vulkan-radv-mtp": {
        "image": "docker.io/kyuz0/amd-strix-halo-toolboxes:vulkan-radv-mtp",
        "engine_args": [
            "--device", "/dev/dri",
            "--group-add", "video",
            "--security-opt", "seccomp=unconfined",
        ],
    },
}
|
||||
|
||||
# ── Benchmark modes ──────────────────────────────────────────────────────────
|
||||
|
||||
# Maps a mode name to the extra llama-server flags appended for that run.
# "baseline" adds nothing; each "mtp-N" mode differs only in the value
# passed to --spec-draft-n-max.
MODES = {
    "baseline": [],
    **{
        f"mtp-{draft_n}": [
            "--spec-type", "draft-mtp",
            "--spec-draft-n-max", str(draft_n),
            "-np", "1",
        ]
        for draft_n in (2, 3)
    },
}
|
||||
|
||||
# ── Constants ────────────────────────────────────────────────────────────────
|
||||
|
||||
# Fixed name of the single podman container used for every run.
CONTAINER_NAME = "mtp-bench-server"

# Seconds to wait for server readiness.
HEALTH_TIMEOUT = 180

# Seconds between health polls.
HEALTH_INTERVAL = 3

# Seconds between runs.
COOLDOWN = 5

# The benchmark client script, located next to this file.
BENCH_SCRIPT = Path(__file__).with_name("mtp-bench.py")
|
||||
|
||||
|
||||
# ── Model discovery ──────────────────────────────────────────────────────────
|
||||
|
||||
def discover_models(models_dir: Path) -> list[dict]:
    """Scan ``models_dir`` recursively for GGUF files with 'MTP' in their path.

    A file qualifies when its path relative to ``models_dir`` contains
    "mtp" (case-insensitive). For multi-shard models named
    ``<base>-NNNNN-of-NNNNN.gguf`` only shard 00001 is returned.

    Args:
        models_dir: Root directory to search.

    Returns:
        One dict per model, in sorted path order, with keys:
        ``name`` — the parent directory name, or the file stem when the
        file sits directly in ``models_dir``;
        ``gguf`` — the path relative to ``models_dir``, as a string.

    Exits the process with status 1 when ``models_dir`` is not a directory.
    """
    # Local import so this function does not depend on the file's import block.
    import re

    if not models_dir.is_dir():
        print(f"❌ Models directory not found: {models_dir}")
        sys.exit(1)

    # Match the shard suffix exactly instead of testing for the substring
    # "-000": the substring test drops unsharded files such as
    # "foo-mtp-0001.gguf" and misses shards numbered 100 and above.
    shard_suffix = re.compile(r"-(\d{5})-of-\d{5}\.gguf$")

    models = []
    for gguf in sorted(models_dir.rglob("*.gguf")):
        rel = gguf.relative_to(models_dir)

        # Must have MTP somewhere in the path (case-insensitive).
        if "mtp" not in str(rel).lower():
            continue

        # Skip non-first shards of multi-shard models.
        shard = shard_suffix.search(gguf.name)
        if shard is not None and int(shard.group(1)) != 1:
            continue

        # Derive the display name from the parent directory or the filename.
        display_name = gguf.parent.name if gguf.parent != models_dir else gguf.stem

        models.append({
            "name": display_name,
            "gguf": str(rel),
        })

    return models
|
||||
|
||||
|
||||
# ── System info ──────────────────────────────────────────────────────────────
|
||||
|
||||
def capture_system_info(results_dir: Path):
    """Write ``system_info.json`` into ``results_dir`` unless it already exists.

    The file records the distro pretty-name, kernel release, the
    ``linux-firmware`` package reported by ``rpm -q``, and a local timestamp.
    Each probe is best-effort and falls back to a placeholder value.
    """
    target = results_dir / "system_info.json"
    if target.exists():
        return

    def get_distro():
        # Value of the first PRETTY_NAME= line in /etc/os-release, unquoted.
        try:
            with open("/etc/os-release") as os_release:
                for entry in os_release:
                    key, sep, value = entry.partition("=")
                    if sep and key == "PRETTY_NAME":
                        return value.strip().strip('"')
        except Exception:
            pass
        return "Linux"

    def get_linux_firmware():
        # Output of `rpm -q linux-firmware`, or "unknown" when rpm is
        # unavailable or reports failure.
        try:
            query = subprocess.run(["rpm", "-q", "linux-firmware"],
                                   capture_output=True, text=True)
        except Exception:
            return "unknown"
        if query.returncode != 0:
            return "unknown"
        return query.stdout.strip()

    info = {
        "distro": get_distro(),
        "kernel": platform.release(),
        "linux_firmware": get_linux_firmware(),
        "timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    }
    serialized = json.dumps(info, indent=2) + "\n"
    target.write_text(serialized)
    print(f"📋 Captured system info → {target}")
|
||||
|
||||
|
||||
# ── Container lifecycle ──────────────────────────────────────────────────────
|
||||
|
||||
def stop_container():
    """Stop, then force-remove, the benchmark container.

    Output and exit codes of both podman calls are discarded, so this does
    not fail when the container does not exist.
    """
    for podman_args in (("stop", "-t", "5"), ("rm", "-f")):
        subprocess.run(["podman", *podman_args, CONTAINER_NAME],
                       capture_output=True)
|
||||
|
||||
|
||||
def start_server(toolbox: dict, gguf: str, models_dir: Path,
                 spec_flags: list[str], port: int) -> bool:
    """Launch llama-server in a detached podman container.

    Any existing benchmark container is stopped first. ``models_dir`` is
    mounted read-only at ``/models`` and ``port`` is published on
    127.0.0.1 only.

    Args:
        toolbox: Entry with an ``image`` and a list of ``engine_args``.
        gguf: Model path relative to ``models_dir``.
        models_dir: Host directory holding the models.
        spec_flags: Extra llama-server flags appended to the command.
        port: Port used both inside the container and on the host.

    Returns:
        True when ``podman run`` exits with status 0, False otherwise.
    """
    stop_container()

    # Everything up to and including the image name is for podman itself.
    podman_part = [
        "podman", "run", "--rm", "-d",
        "--name", CONTAINER_NAME,
        "--security-opt", "label=disable",
        "--userns=keep-id",
    ]
    podman_part += toolbox["engine_args"]
    podman_part += [
        "-v", f"{models_dir}:/models:ro",
        "-p", f"127.0.0.1:{port}:{port}",
        toolbox["image"],
    ]

    # The remainder is the command executed inside the container.
    server_part = [
        "llama-server",
        "-m", f"/models/{gguf}",
        "-c", "12288",
        "-ngl", "999",
        "--host", "0.0.0.0",
        "--port", str(port),
        "--no-mmap",
        "-fa", "1",
        "--jinja",
    ]
    server_part += spec_flags

    cmd = podman_part + server_part

    print(f" 🐳 Starting container...")
    print(f" {' '.join(cmd)}")
    launched = subprocess.run(cmd, capture_output=True, text=True)
    if launched.returncode == 0:
        return True

    print(f" ❌ Failed to start container: {launched.stderr.strip()}")
    return False
|
||||
|
||||
|
||||
def wait_for_health(port: int) -> bool:
    """Poll ``/health`` on 127.0.0.1 until the server reports ready.

    The server counts as ready when the endpoint answers HTTP 200 with a
    JSON object whose ``status`` is ``"ok"``. Polling repeats every
    ``HEALTH_INTERVAL`` seconds for up to ``HEALTH_TIMEOUT`` seconds.

    Args:
        port: Host port the server is published on.

    Returns:
        True once the server is ready. False on timeout, after printing
        the tail of the container logs for debugging.
    """
    log_tail = 30  # lines requested from podman and printed on timeout
    url = f"http://127.0.0.1:{port}/health"
    # Monotonic clock: the deadline must not move if the wall clock is adjusted.
    deadline = time.monotonic() + HEALTH_TIMEOUT
    print(f" ⏳ Waiting for server health ({HEALTH_TIMEOUT}s timeout)...", end="", flush=True)

    while time.monotonic() < deadline:
        try:
            req = request.Request(url, method="GET")
            with request.urlopen(req, timeout=5) as r:
                if r.status == 200:
                    data = json.loads(r.read())
                    # Guard against a non-object JSON body.
                    if isinstance(data, dict) and data.get("status") == "ok":
                        print(f" ✅ ready")
                        return True
        except (URLError, OSError, ValueError):
            # Not reachable yet, or the body was not valid JSON
            # (JSONDecodeError and UnicodeDecodeError are ValueErrors).
            pass
        print(".", end="", flush=True)
        time.sleep(HEALTH_INTERVAL)

    print(f" ❌ timeout")
    # Dump container logs for debugging. stderr is merged into stdout so
    # that output podman writes to stderr (including podman's own error
    # when the container no longer exists) is not silently dropped.
    logs = subprocess.run(
        ["podman", "logs", "--tail", str(log_tail), CONTAINER_NAME],
        stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True,
    )
    tail = (logs.stdout or "").strip()
    if tail:
        print(f" 📝 Last {log_tail} lines of server logs:")
        # Print everything that was fetched, so the count matches the header.
        for line in tail.splitlines()[-log_tail:]:
            print(f" {line}")
    return False
|
||||
|
||||
|
||||
# ── Benchmark execution ──────────────────────────────────────────────────────
|
||||
|
||||
def run_benchmark(port: int, out_path: Path) -> dict | None:
    """Run mtp-bench.py against the local server and load its output.

    The script is invoked with the current interpreter, ``--url`` pointing
    at 127.0.0.1:``port`` and ``--out`` set to ``out_path``. Its output is
    not captured, so it goes to this process's terminal.

    Returns:
        The parsed JSON from ``out_path``, or None when the script exits
        non-zero or the file cannot be read or parsed.
    """
    bench_cmd = [sys.executable, str(BENCH_SCRIPT)]
    bench_cmd += ["--url", f"http://127.0.0.1:{port}"]
    bench_cmd += ["--out", str(out_path)]

    print(f" 🔬 Running benchmark...")
    exit_code = subprocess.run(bench_cmd, capture_output=False).returncode
    if exit_code != 0:
        print(f" ❌ Benchmark failed (exit {exit_code})")
        return None

    try:
        raw_text = out_path.read_text()
        parsed = json.loads(raw_text)
    except Exception as e:
        print(f" ❌ Failed to read results: {e}")
        return None
    return parsed
|
||||
|
||||
|
||||
# ── Results handling ─────────────────────────────────────────────────────────
|
||||
|
||||
def save_result(out_path: Path, model: dict, toolbox_name: str,
                mode: str, spec_flags: list[str], bench_data: dict):
    """Write the benchmark output to ``out_path`` with run metadata prepended.

    The metadata keys are ``model``, ``gguf``, ``toolbox``, ``mode``,
    ``spec_flags`` and ``timestamp``. The entries of ``bench_data`` are
    merged in afterwards, so a key present in both takes its value from
    ``bench_data``.
    """
    if spec_flags:
        flags_text = " ".join(spec_flags)
    else:
        flags_text = "(none)"

    record = dict(
        model=model["name"],
        gguf=model["gguf"],
        toolbox=toolbox_name,
        mode=mode,
        spec_flags=flags_text,
        timestamp=datetime.datetime.now().isoformat(),
    )
    record.update(bench_data)

    out_path.write_text(json.dumps(record, indent=2) + "\n")
|
||||
|
||||
|
||||
def print_summary(results_dir: Path):
    """Print a table of all saved results and write ``summary.json``.

    Every ``*.json`` in ``results_dir`` except ``system_info.json`` and
    ``summary.json`` is considered. A file is used only when it parses to
    a JSON object that has an ``aggregate`` key and the run metadata keys
    ``model``, ``toolbox`` and ``mode``. Files with ``aggregate`` but no
    metadata are reported and skipped rather than crashing the summary.

    Average tok/s is the mean of ``predicted_per_second`` over the entries
    of ``results``. Speedup is that average divided by the baseline
    average for the same (model, toolbox) pair. When nothing usable is
    found, a notice is printed and no summary file is written.
    """
    required_keys = ("model", "toolbox", "mode")

    results = []
    for f in sorted(results_dir.glob("*.json")):
        if f.name in ("system_info.json", "summary.json"):
            continue
        try:
            data = json.loads(f.read_text())
        except (OSError, ValueError):
            # Unreadable or not valid JSON: ignore the file.
            continue
        if not isinstance(data, dict) or "aggregate" not in data:
            continue
        if any(key not in data for key in required_keys):
            print(f"⚠️ Skipping {f.name}: missing run metadata")
            continue
        results.append(data)

    if not results:
        print("\n📊 No results found.")
        return

    def average_tok_s(result: dict) -> float:
        # Mean tok/s over all prompts; null or missing rates count as 0.
        prompts = result.get("results") or []
        if not prompts:
            return 0
        return sum((p.get("predicted_per_second") or 0) for p in prompts) / len(prompts)

    # Kept alongside the results instead of being written into them.
    averages = [average_tok_s(r) for r in results]

    # Build baseline lookup for speedup calculation.
    baselines = {
        (r["model"], r["toolbox"]): avg
        for r, avg in zip(results, averages)
        if r["mode"] == "baseline"
    }

    # Print table.
    print("\n" + "=" * 100)
    print(f"{'Model':<30} {'Toolbox':<20} {'Mode':<10} {'Avg tok/s':>10} {'Accept%':>9} {'Wall(s)':>8} {'Speedup':>8}")
    print("-" * 100)

    summary_data = []
    for r, avg_toks in zip(results, averages):
        # "aggregate" may be null in the file; treat that like an empty object.
        agg = r.get("aggregate") or {}
        accept = agg.get("aggregate_accept_rate")
        # A null wall time would break the numeric format spec below.
        wall = agg.get("wall_s_total") or 0
        accept_str = f"{accept * 100:.1f}%" if accept is not None else "—"

        # Speedup relative to the baseline of the same model and toolbox.
        baseline_toks = baselines.get((r["model"], r["toolbox"]))
        if baseline_toks and baseline_toks > 0:
            speedup = f"{avg_toks / baseline_toks:.2f}×"
        else:
            speedup = "—"

        print(f"{r['model']:<30} {r['toolbox']:<20} {r['mode']:<10} {avg_toks:>10.1f} {accept_str:>9} {wall:>8.1f} {speedup:>8}")

        summary_data.append({
            "model": r["model"],
            "toolbox": r["toolbox"],
            "mode": r["mode"],
            "avg_tok_s": round(avg_toks, 1),
            "accept_rate": agg.get("aggregate_accept_rate"),
            "wall_s_total": agg.get("wall_s_total"),
        })

    print("=" * 100)

    # Write summary.json.
    summary_path = results_dir / "summary.json"
    summary_path.write_text(json.dumps(summary_data, indent=2) + "\n")
    print(f"\n📄 Summary written to {summary_path}")
|
||||
|
||||
|
||||
# ── Main ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
def main():
    """Parse arguments, run every model × toolbox × mode, then summarize.

    For each combination whose result file does not exist yet, this starts
    the server container, waits for it to become healthy, runs the
    benchmark, stops the container and saves the result with metadata.
    Results are written to ``results-mtp`` next to this script.

    The container is stopped even when a run is interrupted or raises.
    The benchmark's raw output goes to a temporary ``.partial`` file, so
    the final result path exists only for completed runs; a failed run
    therefore cannot cause later invocations to skip it.
    """
    ap = argparse.ArgumentParser(description="MTP Benchmark Runner")
    ap.add_argument("--models-dir", type=Path, default=Path.home() / "models",
                    help="Directory containing GGUF models (default: ~/models)")
    ap.add_argument("--model", type=str, default=None,
                    help="Filter: only run models whose name contains this string")
    ap.add_argument("--toolbox", type=str, default=None,
                    help="Filter: only run this toolbox (e.g. 'vulkan-radv-mtp')")
    ap.add_argument("--port", type=int, default=8080,
                    help="Port for llama-server (default: 8080)")
    args = ap.parse_args()

    models_dir = args.models_dir.expanduser().resolve()
    results_dir = Path(__file__).parent / "results-mtp"
    results_dir.mkdir(exist_ok=True)

    # Discover models.
    models = discover_models(models_dir)
    if args.model:
        models = [m for m in models if args.model.lower() in m["name"].lower()]

    if not models:
        if args.model:
            # Say that the filter, not the directory, may be the cause.
            print(f"❌ No MTP models matching '{args.model}' found in {models_dir}")
        else:
            print(f"❌ No MTP models found in {models_dir}")
        sys.exit(1)

    # Filter toolboxes.
    toolboxes = TOOLBOXES
    if args.toolbox:
        if args.toolbox not in TOOLBOXES:
            print(f"❌ Unknown toolbox: {args.toolbox}")
            print(f" Available: {', '.join(TOOLBOXES.keys())}")
            sys.exit(1)
        toolboxes = {args.toolbox: TOOLBOXES[args.toolbox]}

    # Print run plan.
    print(f"\n🔍 Discovered {len(models)} MTP model(s):")
    for m in models:
        print(f" • {m['name']} → {m['gguf']}")

    print(f"\n🧰 Toolboxes: {', '.join(toolboxes.keys())}")
    print(f"📊 Modes: {', '.join(MODES.keys())}")
    total = len(models) * len(toolboxes) * len(MODES)
    print(f"📋 Total runs: {total}\n")

    # Capture system info.
    capture_system_info(results_dir)

    # Run benchmarks.
    run_count = 0
    for tb_name, tb_config in toolboxes.items():
        for model in models:
            for mode_name, spec_flags in MODES.items():
                run_count += 1
                out_file = results_dir / f"{model['name']}__{tb_name}__{mode_name}.json"
                # Raw benchmark output goes here. The suffix keeps it out of
                # the "*.json" glob and the "results already exist" check.
                raw_file = out_file.with_name(out_file.name + ".partial")

                print(f"\n{'─' * 80}")
                print(f"▶ [{run_count}/{total}] {model['name']} | {tb_name} | {mode_name}")
                print(f" Output: {out_file.name}")

                # Skip if results exist.
                if out_file.exists():
                    print(f" ⏩ Skipping — results already exist")
                    continue

                # Start server.
                if not start_server(tb_config, model["gguf"], models_dir,
                                    spec_flags, args.port):
                    stop_container()
                    continue

                healthy = False
                bench_data = None
                try:
                    # Wait for health, then run the benchmark.
                    healthy = wait_for_health(args.port)
                    if healthy:
                        bench_data = run_benchmark(args.port, raw_file)
                finally:
                    # Runs on Ctrl-C and errors too, so the detached
                    # container is never left behind.
                    print(f" 🛑 Stopping container...")
                    stop_container()
                    raw_file.unlink(missing_ok=True)

                if not healthy:
                    continue

                # Wrap result with metadata.
                if bench_data:
                    save_result(out_file, model, tb_name, mode_name,
                                spec_flags, bench_data)
                    print(f" ✅ Done — saved to {out_file.name}")
                else:
                    print(f" ❌ No results collected")

                # Cooldown.
                if run_count < total:
                    print(f" 💤 Cooldown ({COOLDOWN}s)...")
                    time.sleep(COOLDOWN)

    # Print summary.
    print_summary(results_dir)
|
||||
|
||||
|
||||
# Run the benchmark suite only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user