added script to manage cluster

This commit is contained in:
Donato Capitella
2026-01-14 17:01:00 +00:00
parent 6d70dfc73b
commit 8da5395366
2 changed files with 568 additions and 4 deletions
+31 -4
View File
@@ -50,8 +50,9 @@ AMD has recalled this update, but if you have already installed it, you must dow
6.1 [Test Configuration](#61-test-configuration)
6.2 [Kernel Parameters (tested on Fedora 42)](#62-kernel-parameters-tested-on-fedora-42)
6.3 [Ubuntu 24.04](#63-ubuntu-2404)
7. [More Documentation](#7-more-documentation)
8. [References](#8-references)
7. [Distributed Inference on Strix Halo Clusters](#7-distributed-inference-on-strix-halo-clusters)
8. [More Documentation](#8-more-documentation)
9. [References](#9-references)
## Quick Answers (Read This First)
@@ -380,13 +381,39 @@ Follow this guide by TechnigmaAI for a working configuration on Ubuntu 24.04:
[https://github.com/technigmaai/technigmaai-wiki/wiki/AMD-Ryzen-AI-Max--395:-GTT--Memory-Step%E2%80%90by%E2%80%90Step-Instructions-(Ubuntu-24.04)](https://github.com/technigmaai/technigmaai-wiki/wiki/AMD-Ryzen-AI-Max--395:-GTT--Memory-Step%E2%80%90by%E2%80%90Step-Instructions-%28Ubuntu-24.04%29)
## 7. More Documentation
## 7. Distributed Inference on Strix Halo Clusters
You can use the included `run_distributed_llama.py` script to run models on a cluster of Strix Halo machines.
### Setup
1. **Install Toolboxes**: Install the required toolboxes on each node (main and workers).
2. **Download Models**: Download the model weights to the main node.
3. **SSH Configuration**: Configure passwordless SSH authentication from the main node to every worker node. The script relies on being able to SSH into the workers without any user interaction.
```bash
ssh-copy-id user@worker-node-ip
```
### Execution
Run the script on the main node **outside of the toolbox**:
```bash
python3 run_distributed_llama.py
```
This will launch the TUI, where you can:
1. Configure the list of worker nodes (IPs).
2. Select the model and toolbox version.
3. Start the distributed inference.
The script automatically starts the necessary toolbox containers locally and on the remote nodes to handle the inference.
## 8. More Documentation
* [docs/benchmarks.md](docs/benchmarks.md): Full benchmark logs, model list, parsed results
* [docs/vram-estimator.md](docs/vram-estimator.md): Memory planning, practical example runs
* [docs/building.md](docs/building.md): Local build, toolbox customization, advanced use
## 8. References
## 9. References
* The main reference for AMD Ryzen AI MAX home labs, by deseven (there's also a Discord server): [https://strixhalo-homelab.d7.wtf/](https://strixhalo-homelab.d7.wtf/)
* Most comprehensive repository of test builds for Strix Halo by lhl -> [https://github.com/lhl/strix-halo-testing/tree/main](https://github.com/lhl/strix-halo-testing/tree/main)
+537
View File
@@ -0,0 +1,537 @@
#!/usr/bin/env python3
import sys
import os
import shutil
import tempfile
import subprocess
import time
import signal
from pathlib import Path
# --- Configuration & Defaults ---
# Absolute path of the directory that holds this script.
SCRIPT_DIR = Path(__file__).parent.resolve()

# Toolbox key shown in the TUI -> name of the toolbox container passed to
# `toolbox run -c`.
TOOLBOX_IMAGES = {
    "rocm6_4_4": "llama-rocm-6.4.4",
    "rocm7_1_1": "llama-rocm-7.1.1",
    "rocm7-nightlies": "llama-rocm7-nightlies",
    "vulkan_amdvlk": "llama-vulkan-amdvlk",
    "vulkan_radv": "llama-vulkan-radv",
}
DEFAULT_TOOLBOX = "rocm7-nightlies"

# Executables that can be launched on the main node; the first is the default.
MODES = ["llama-server", "llama-cli", "llama-bench"]
DEFAULT_MODE = MODES[0]

# Default RPC hosts as (ip, enabled) pairs.
DEFAULT_HOSTS = [
    (f"192.168.100.{last_octet}", True) for last_octet in (11, 12, 13)
]

# Ports are kept as strings because they are spliced into command lines.
# REMOTE_PORT is the SSH port and RPC_PORT the rpc-server port; both can be
# overridden through the environment.
REMOTE_PORT = os.environ.get("REMOTE_PORT", "22")
RPC_PORT = os.environ.get("RPC_PORT", "50052")
LOCAL_HOST_PORT = "8080"
# --- Helper Functions ---
def check_dependencies():
    """Exit with status 1 unless the ``dialog`` executable is on PATH."""
    if shutil.which("dialog") is not None:
        return
    print("Error: 'dialog' is required. Please install it (e.g., sudo apt-get install dialog).")
    sys.exit(1)
def run_dialog(args):
    """Run ``dialog`` with *args* and report what the user chose.

    dialog's stderr is redirected into a temporary file and read back once
    the process has finished.

    Returns:
        A ``(selection, exit_code)`` tuple. ``selection`` is the stripped
        stderr text; if running dialog raises, the result is ``(None, 1)``.
    """
    command = ["dialog"] + args
    with tempfile.NamedTemporaryFile(mode="w+") as capture:
        try:
            # check=False: a non-zero exit code is reported to the caller
            # rather than raised.
            completed = subprocess.run(command, stderr=capture, check=False)
            capture.seek(0)
            output = capture.read()
        except Exception:
            return None, 1
    return output.strip(), completed.returncode
def show_msg(title, msg):
    """Show *msg* in a 10x60 dialog message box titled *title*."""
    box_args = ["--title", title, "--msgbox", msg, "10", "60"]
    run_dialog(box_args)
# --- Custom File Picker ---
def get_directory_contents(path):
    """List the sub-directories and ``.gguf`` files directly inside *path*.

    Args:
        path: Directory to inspect.

    Returns:
        A ``(dirs, files)`` tuple of sorted name lists (names only, not full
        paths). Both lists are empty when *path* is not a directory or
        cannot be listed.
    """
    try:
        if not os.path.isdir(path):
            return [], []
        entries = os.listdir(path)
    except OSError:
        # Covers PermissionError as well as a directory that vanished or hit
        # an I/O error between the isdir() check and the listing.
        return [], []

    dirs = []
    files = []
    for name in entries:
        if os.path.isdir(os.path.join(path, name)):
            dirs.append(name)
        elif name.endswith(".gguf"):  # Only GGUF model files are offered
            files.append(name)
    return sorted(dirs), sorted(files)
def custom_file_picker(start_path):
    """Browse directories with ``dialog`` menus until a file is picked.

    Browsing starts at *start_path*, or at the current working directory
    when *start_path* is not a directory.

    Returns:
        The full path of the selected file, or ``None`` if the dialog was
        cancelled.
    """
    current_path = os.path.abspath(start_path)
    if not os.path.isdir(current_path):
        current_path = os.getcwd()

    while True:
        dirs, files = get_directory_contents(current_path)

        # dialog --menu takes a flat sequence of (tag, description) pairs.
        menu_items = []
        if current_path != "/":
            menu_items += ["..", "Parent Directory"]
        for dir_name in dirs:
            menu_items += [dir_name + "/", "<DIR>"]
        for file_name in files:
            menu_items += [file_name, "<GGUF>"]
        if not menu_items:
            menu_items += [".", "Empty Directory"]

        # Keep only the tail of long paths so the back title stays short.
        if len(current_path) > 50:
            pretty_path = "..." + current_path[-47:]
        else:
            pretty_path = current_path

        selection, code = run_dialog([
            "--title", "Select GGUF File",
            "--backtitle", f"Current: {pretty_path}",
            "--menu", "Navigate directories and select a .gguf file:", "20", "70", "12",
            *menu_items
        ])
        if code != 0:  # Cancel/Escape
            return None

        choice = selection.strip()
        if choice == "..":
            current_path = os.path.dirname(current_path)
            continue
        if choice.endswith("/"):
            # Directory tags carry a trailing slash; drop it to descend.
            current_path = os.path.join(current_path, choice[:-1])
            continue
        if choice != ".":
            # Anything else is a file name; "." just redraws the same menu.
            return os.path.join(current_path, choice)
# --- Main Logic ---
class AppState:
    """Mutable settings edited through the menus and read when running."""

    def __init__(self):
        self.model_path = ""
        self.toolbox = DEFAULT_TOOLBOX
        self.mode = DEFAULT_MODE
        # None means "model default": no -c argument is passed.
        self.context_size = None
        # Mutable [ip, enabled] pairs copied from the defaults.
        self.hosts = [[ip, enabled] for ip, enabled in DEFAULT_HOSTS]

    @property
    def active_hosts(self):
        """IPs of the hosts that are currently enabled, in list order."""
        return [ip for ip, enabled in self.hosts if enabled]
def select_model(state):
    """Let the user pick a model file and store its path on *state*.

    The picker opens in the directory of the current model, or in the
    working directory when no model is set. A cancelled picker leaves
    ``state.model_path`` untouched.
    """
    start_path = state.model_path or os.getcwd()
    if os.path.isfile(start_path):
        start_path = os.path.dirname(start_path)

    chosen = custom_file_picker(start_path)
    if chosen:
        state.model_path = chosen
def select_toolbox(state):
    """Show the toolbox menu and store the chosen key in ``state.toolbox``."""
    # Flatten the mapping into dialog's (tag, description) sequence.
    menu_items = [
        field
        for key, image in TOOLBOX_IMAGES.items()
        for field in (key, image)
    ]

    selection, code = run_dialog([
        "--title", "Select Toolbox",
        "--menu", "Choose the container environment:", "15", "60", "8",
        *menu_items
    ])
    if code != 0 or not selection:
        return
    state.toolbox = selection
def select_mode(state):
    """Show the execution-mode menu and store the choice in ``state.mode``."""
    # Every mode is its own tag with an empty description column.
    menu_items = [field for mode in MODES for field in (mode, "")]

    selection, code = run_dialog([
        "--title", "Select Execution Mode",
        "--menu", "Choose how to run the model:", "12", "50", "5",
        *menu_items
    ])
    if code != 0 or not selection:
        return
    state.mode = selection
def select_context(state):
    """Prompt for the context size and store it in ``state.context_size``.

    An empty entry (or ``0``) selects the model default, stored as ``None``.
    A non-numeric entry is rejected with an error box and the previous
    setting is kept. A cancelled dialog changes nothing.
    """
    current = str(state.context_size) if state.context_size else ""
    selection, code = run_dialog([
        "--title", "Context Size",
        "--inputbox", "Enter context size (e.g. 4096, 8192).\nLeave empty for model default:", "10", "60",
        current
    ])
    if code != 0:
        return

    val = selection.strip()
    if not val:
        state.context_size = None
    elif val.isascii() and val.isdigit():
        # isascii() guards int(): str.isdigit() alone also accepts
        # characters such as superscript digits, which int() rejects.
        # A value of 0 is normalised to None, i.e. the model default.
        state.context_size = int(val) or None
    else:
        show_msg("Error", f"Invalid context size: {val}\nPlease enter a whole number.")
def add_server(state):
    """Prompt for a server address and append it to ``state.hosts``.

    The new entry is enabled by default. Empty input and a cancelled dialog
    are ignored. An address that is already listed is rejected with an info
    box, because running the same host twice would start two RPC servers on
    it and repeat the endpoint in the ``--rpc`` argument.
    """
    selection, code = run_dialog([
        "--title", "Add Server",
        "--inputbox", "Enter new server IP address:", "10", "50"
    ])
    if code != 0:
        return

    ip = selection.strip()
    if not ip:
        return
    if any(host[0] == ip for host in state.hosts):
        show_msg("Info", f"Server {ip} is already in the list.")
        return

    # Default to enabled
    state.hosts.append([ip, True])
def remove_server(state):
    """Let the user pick one configured server and delete it from *state*."""
    if not state.hosts:
        show_msg("Info", "No servers to remove.")
        return

    # Menu tags are list indices so the choice maps straight back to an entry.
    items = []
    for index, (ip, _enabled) in enumerate(state.hosts):
        items += [str(index), ip]

    selection, code = run_dialog([
        "--title", "Remove Server",
        "--menu", "Select server to remove:", "15", "50", "5",
        *items
    ])
    if code != 0 or not selection:
        return

    idx = int(selection)
    if 0 <= idx < len(state.hosts):
        state.hosts.pop(idx)
def edit_server(state):
    """Let the user pick a configured server and change its address.

    The enabled flag of the entry is preserved. Empty input and cancelled
    dialogs leave the entry unchanged. An address already used by another
    entry is rejected with an info box so the host list never contains the
    same server twice.
    """
    if not state.hosts:
        show_msg("Info", "No servers to edit.")
        return

    # Menu tags are list indices so the choice maps straight back to an entry.
    items = []
    for index, (ip, _enabled) in enumerate(state.hosts):
        items.extend([str(index), ip])

    selection, code = run_dialog([
        "--title", "Edit Server",
        "--menu", "Select server to edit:", "15", "50", "5",
        *items
    ])
    if code != 0 or not selection:
        return

    idx = int(selection)
    if not 0 <= idx < len(state.hosts):
        return

    current_ip = state.hosts[idx][0]
    new_ip, code2 = run_dialog([
        "--title", "Edit Server IP",
        "--inputbox", "Enter new IP address:", "10", "50",
        current_ip
    ])
    if code2 != 0:
        return

    clean_ip = new_ip.strip()
    if not clean_ip:
        return
    if any(i != idx and host[0] == clean_ip for i, host in enumerate(state.hosts)):
        show_msg("Info", f"Server {clean_ip} is already in the list.")
        return
    state.hosts[idx][0] = clean_ip
def toggle_servers(state):
    """Show a checklist of all servers and apply the resulting on/off state."""
    if not state.hosts:
        show_msg("Info", "No servers to configure. Add some first.")
        return

    # checklist entries are (tag, text, on/off) triples; the tag is the index.
    items = []
    for index, (ip, enabled) in enumerate(state.hosts):
        items += [str(index), ip, "on" if enabled else "off"]

    selection_str, code = run_dialog([
        "--title", "Toggle Active Servers",
        "--checklist", "Select active servers (Space to toggle):", "15", "50", "5",
        *items
    ])
    if code != 0:
        return

    # Everything is switched off first; only the ticked entries come back on.
    for host in state.hosts:
        host[1] = False
    if not selection_str:
        return

    # The output is a space-separated tag list such as "0 2"; tags may be
    # wrapped in double quotes.
    ticked = [int(tag.strip('"')) for tag in selection_str.split()]
    for idx in ticked:
        if 0 <= idx < len(state.hosts):
            state.hosts[idx][1] = True
def configure_servers(state):
    """Run the server-management submenu until the user backs out."""
    # Menu tag -> function that carries out the action on *state*.
    actions = {
        "1": toggle_servers,
        "2": add_server,
        "3": remove_server,
        "4": edit_server,
    }
    menu = [
        "1", "Toggle Active Servers",
        "2", "Add Server",
        "3", "Remove Server",
        "4", "Edit Server",
        "5", "Back"
    ]

    while True:
        selection, code = run_dialog([
            "--title", "Manage Remote Servers",
            "--menu", "Choose an action:", "15", "50", "5",
            *menu
        ])
        # Cancel/Escape and the explicit "Back" entry both leave the submenu.
        if code != 0 or selection == "5":
            break

        action = actions.get(selection)
        if action is not None:
            action(state)
def _start_remote_rpc_server(ip, image):
    """Start rpc-server on *ip* over SSH; return its PID string or None.

    Any failure is printed before None is returned.
    """
    # The script is fed to a remote `bash -s` on stdin: it kills any previous
    # rpc-server, starts a new one in the background and prints its PID.
    # We assume the 'toolbox' command exists on the remote.
    cmd_str = f"""
set -euo pipefail
pkill -9 -f rpc-server || true
nohup toolbox run -c {image} -- rpc-server -H 0.0.0.0 -p {RPC_PORT} -c > /tmp/rpc-server-{ip}.log 2>&1 < /dev/null &
echo $!
"""
    res = subprocess.run(
        ["ssh", "-p", REMOTE_PORT, ip, "bash -s"],
        input=cmd_str, text=True, capture_output=True
    )
    if res.returncode != 0:
        print(f"[ERROR] SSH failed for {ip}: {res.stderr}")
        return None

    output = res.stdout.strip()
    lines = output.splitlines()
    # Only the last line is trusted as the PID, in case the remote side
    # printed something before it.
    if lines and lines[-1].isdigit():
        return lines[-1]
    print(f"[ERROR] Invalid PID returned from {ip}: {output}")
    return None


def _wait_for_port(ip, port, attempts=30, delay=1):
    """Return True once a TCP connection to ip:port succeeds, else False.

    Up to *attempts* connections are tried, sleeping *delay* seconds after
    each failure.
    """
    # Imported locally so this helper does not depend on the file's import block.
    import socket

    for _ in range(attempts):
        try:
            # The context manager closes the probe socket on every path.
            with socket.create_connection((ip, port), timeout=1):
                return True
        except OSError:
            time.sleep(delay)
    return False


def _build_local_command(state, image, rpc_arg):
    """Return the argv list that runs ``state.mode`` inside the toolbox."""
    cmd = [
        "toolbox", "run", "-c", image, "--",
        state.mode,
        "-m", state.model_path,
        "--rpc", rpc_arg
    ]
    if state.mode == "llama-server":
        cmd += [
            "--no-mmap",
            "-fa", "1",
            "--host", "0.0.0.0",
            "--port", LOCAL_HOST_PORT
        ]
    elif state.mode == "llama-cli":
        cmd += [
            "--no-mmap",
            "-fa", "1",
            "-cnv",  # Conversation mode
            "-p", "You are a helpful assistant."
        ]
    elif state.mode == "llama-bench":
        # llama-bench is given -mmp 0 where the other modes get --no-mmap.
        cmd += [
            "-mmp", "0",
            "-fa", "1"
        ]

    # Only the server and CLI modes receive an explicit context size.
    if state.context_size and state.mode in ("llama-server", "llama-cli"):
        cmd += ["-c", str(state.context_size)]
    return cmd


def run_distributed(state):
    """Start rpc-server on every active host, then run the local executable.

    Steps:
      1. Validate that a model file and at least one active host are set.
      2. Start an RPC server on each active host over SSH and wait until its
         port accepts connections.
      3. Run ``state.mode`` locally inside the selected toolbox, pointing it
         at all RPC endpoints, and wait for it to finish.
      4. Kill the remote RPC servers. This runs exactly once, whether the
         run finished, failed or was interrupted.

    Ctrl-C cleans up and exits the program. The SIGINT handler installed for
    the run is removed again before this function returns.
    """
    if not state.model_path or not os.path.exists(state.model_path):
        show_msg("Error", f"Model file not found:\n{state.model_path}")
        return
    if not state.active_hosts:
        show_msg("Error", "No remote servers selected.")
        return

    image = TOOLBOX_IMAGES[state.toolbox]
    active_ips = state.active_hosts

    # Clear screen for execution output
    subprocess.run(["clear"])
    print("=== Starting Distributed Run ===")
    print(f"Model: {state.model_path}")
    print(f"Toolbox: {state.toolbox} ({image})")
    print(f"Mode: {state.mode}")
    print(f"Context: {state.context_size or 'Default'}")
    print(f"Hosts: {active_ips}")
    print("--------------------------------")

    remote_pids = []  # PID strings, index-aligned with active_ips
    cleaned_up = False
    interrupted = False

    def cleanup():
        """Kill the remote RPC servers started so far (runs at most once)."""
        nonlocal cleaned_up
        if cleaned_up:
            return
        cleaned_up = True
        print("\nCleaning up...")
        # zip() stops at the hosts that actually got as far as reporting a PID.
        for ip, pid in zip(active_ips, remote_pids):
            print(f"Killing remote RPC on {ip} (PID: {pid})...")
            subprocess.run(
                ["ssh", "-p", REMOTE_PORT, ip, f"kill -9 {pid} 2>/dev/null || true; pkill -9 -f rpc-server || true"],
                stderr=subprocess.DEVNULL
            )

    def signal_handler(sig, frame):
        nonlocal interrupted
        interrupted = True
        cleanup()
        sys.exit(0)

    # Remember the previous handler so it can be put back afterwards;
    # otherwise this closure would keep handling Ctrl-C in the menus.
    previous_handler = signal.signal(signal.SIGINT, signal_handler)

    try:
        # 1. Start Remote RPC Servers
        rpc_arg_parts = []
        for ip in active_ips:
            print(f"-> Starting RPC server on {ip}...")
            pid = _start_remote_rpc_server(ip, image)
            if pid is None:
                return  # the finally block performs the cleanup
            # Record the PID before waiting so cleanup covers this host too.
            remote_pids.append(pid)
            print(f" PID: {pid}")

            print(f" Waiting for port {RPC_PORT}...", end="", flush=True)
            if not _wait_for_port(ip, int(RPC_PORT)):
                print(" TIMEOUT.")
                print(f"[ERROR] Failed to connect to {ip}:{RPC_PORT}")
                return
            print(" OK.")
            rpc_arg_parts.append(f"{ip}:{RPC_PORT}")

        rpc_arg = ",".join(rpc_arg_parts)
        print(f"All servers ready. RPC Arg: {rpc_arg}")
        print(f"Starting Local {state.mode}...")
        print("--------------------------------")

        # 2. Run Local Executable
        local_cmd = _build_local_command(state, image, rpc_arg)
        print(f"CMD: {' '.join(local_cmd)}")
        proc = subprocess.Popen(local_cmd)
        proc.wait()
    except Exception as e:
        print(f"\n[EXCEPTION] {e}")
    finally:
        try:
            cleanup()
        finally:
            # signal.signal() returns None for a handler that was not
            # installed from Python; that value cannot be re-installed.
            if previous_handler is not None:
                signal.signal(signal.SIGINT, previous_handler)
        # After Ctrl-C the process is exiting, so the "return to menu" prompt
        # would be misleading.
        if not interrupted:
            input("\nRun complete. Press Enter to return to menu...")
def main_menu():
    """Run the top-level menu loop until the user exits.

    Each pass rebuilds the menu so the labels show the current model,
    toolbox, number of active servers, mode and context size. Leaving the
    menu (Exit entry, Cancel or Escape) clears the screen and exits the
    process with status 0.
    """
    state = AppState()
    # Menu tag -> function that edits or uses the shared state.
    handlers = {
        "1": select_model,
        "2": select_toolbox,
        "3": configure_servers,
        "4": select_mode,
        "5": select_context,
        "6": run_distributed,
    }

    while True:
        model_display = Path(state.model_path).name if state.model_path else "(None)"
        servers_display = f"{len(state.active_hosts)} Active"
        context_display = str(state.context_size) if state.context_size else "Default"

        menu = [
            "--clear", "--backtitle", "AMD Strix Halo - Distributed Llama",
            "--title", "Main Menu",
            "--menu", "Select an option to configure or run:", "20", "60", "7",
            "1", f"Model: {model_display}",
            "2", f"Toolbox: {state.toolbox}",
            "3", f"Servers: {servers_display}",
            "4", f"Mode: {state.mode}",
            "5", f"Context: {context_display}",
            "6", "RUN DISTRIBUTED SERVER",
            "7", "Exit"
        ]

        choice, code = run_dialog(menu)
        # Cancel/Esc and the explicit Exit entry both end the loop.
        if code != 0 or choice == "7":
            break

        handler = handlers.get(choice)
        if handler is not None:
            handler(state)

    subprocess.run(["clear"])
    # sys.exit matches the rest of the file; the bare exit() builtin is only
    # provided by the site module.
    sys.exit(0)
if __name__ == "__main__":
    # Fail early when the dialog binary is missing, then hand over to the TUI.
    check_dependencies()
    try:
        main_menu()
    except KeyboardInterrupt:
        # Ctrl-C in a menu: leave a clean terminal and exit successfully.
        subprocess.run(["clear"])
        sys.exit(0)