diff --git a/README.md b/README.md index 6ab2493..33c7924 100644 --- a/README.md +++ b/README.md @@ -50,8 +50,9 @@ AMD has recalled this update, but if you have already installed it, you must dow 6.1 [Test Configuration](#61-test-configuration) 6.2 [Kernel Parameters (tested on Fedora 42)](#62-kernel-parameters-tested-on-fedora-42) 6.3 [Ubuntu 24.04](#63-ubuntu-2404) -7. [More Documentation](#7-more-documentation) -8. [References](#8-references) +7. [Distributed Inference on Strix Halo Clusters](#7-distributed-inference-on-strix-halo-clusters) +8. [More Documentation](#8-more-documentation) +9. [References](#9-references) ## Quick Answers (Read This First) @@ -380,13 +381,39 @@ Follow this guide by TechnigmaAI for a working configuration on Ubuntu 24.04: [https://github.com/technigmaai/technigmaai-wiki/wiki/AMD-Ryzen-AI-Max--395:-GTT--Memory-Step%E2%80%90by%E2%80%90Step-Instructions-(Ubuntu-24.04)](https://github.com/technigmaai/technigmaai-wiki/wiki/AMD-Ryzen-AI-Max--395:-GTT--Memory-Step%E2%80%90by%E2%80%90Step-Instructions-%28Ubuntu-24.04%29) -## 7. More Documentation +## 7. Distributed Inference on Strix Halo Clusters + +You can use the included `run_distributed_llama.py` script to run models on a cluster of Strix Halo machines. + +### Setup +1. **Install Toolboxes**: Install the required toolboxes on each node (main and workers). +2. **Download Models**: Download the model weights to the main node. +3. **SSH Configuration**: Configure SSH passwordless authentication from the main node to all other nodes. The script relies on being able to SSH into workers without user interaction. + ```bash + ssh-copy-id user@worker-node-ip + ``` + +### Execution +Run the script on the main node **outside of the toolbox**: + +```bash +python3 run_distributed_llama.py +``` + +This will launch the TUI, where you can: +1. Configure the list of worker nodes (IPs). +2. Select the model and toolbox version. +3. Start the distributed inference. + +The script automatically starts the necessary toolbox containers locally and on the remote nodes to handle the inference. + +## 8. More Documentation * [docs/benchmarks.md](docs/benchmarks.md): Full benchmark logs, model list, parsed results * [docs/vram-estimator.md](docs/vram-estimator.md): Memory planning, practical example runs * [docs/building.md](docs/building.md): Local build, toolbox customization, advanced use -## 8. References +## 9. References * The main reference for AMD Ryzen AI MAX home labs, by deseven (there's also a Discord server): [https://strixhalo-homelab.d7.wtf/](https://strixhalo-homelab.d7.wtf/) * Most comprehesive repostiry of test builds for Strix Halo by lhl -> [https://github.com/lhl/strix-halo-testing/tree/main](https://github.com/lhl/strix-halo-testing/tree/main) diff --git a/run_distributed_llama.py b/run_distributed_llama.py new file mode 100755 index 0000000..a160e08 --- /dev/null +++ b/run_distributed_llama.py @@ -0,0 +1,537 @@ +#!/usr/bin/env python3 +import sys +import os +import shutil +import tempfile +import subprocess +import time +import signal +from pathlib import Path + +# --- Configuration & Defaults --- +SCRIPT_DIR = Path(__file__).parent.resolve() +DEFAULT_TOOLBOX = "rocm7-nightlies" +TOOLBOX_IMAGES = { + "rocm6_4_4": "llama-rocm-6.4.4", + "rocm7_1_1": "llama-rocm-7.1.1", + "rocm7-nightlies": "llama-rocm7-nightlies", + "vulkan_amdvlk": "llama-vulkan-amdvlk", + "vulkan_radv": "llama-vulkan-radv", +} + +MODES = ["llama-server", "llama-cli", "llama-bench"] +DEFAULT_MODE = "llama-server" + +# Default RPC Hosts +DEFAULT_HOSTS = [ + ("192.168.100.11", True), + ("192.168.100.12", True), + ("192.168.100.13", True), +] + +REMOTE_PORT = os.getenv("REMOTE_PORT", "22") +RPC_PORT = os.getenv("RPC_PORT", "50052") +LOCAL_HOST_PORT = "8080" + + +# --- Helper Functions --- + +def check_dependencies(): + if not shutil.which("dialog"): + print("Error: 'dialog' is required. Please install it (e.g., sudo apt-get install dialog).") + sys.exit(1) + +def run_dialog(args): + """Runs dialog and returns stderr (selection), and exit code.""" + with tempfile.NamedTemporaryFile(mode="w+") as tf: + cmd = ["dialog"] + args + try: + res = subprocess.run(cmd, stderr=tf, check=False) # Check False to handle exit codes + tf.seek(0) + return tf.read().strip(), res.returncode + except Exception: + return None, 1 + +def show_msg(title, msg): + run_dialog(["--title", title, "--msgbox", msg, "10", "60"]) + +# --- Custom File Picker --- + +def get_directory_contents(path): + try: + if not os.path.isdir(path): + return [], [] + + entries = os.listdir(path) + dirs = [] + files = [] + + for e in entries: + full_path = os.path.join(path, e) + if os.path.isdir(full_path): + dirs.append(e) + elif e.endswith(".gguf"): # Filter for GGUF + files.append(e) + + dirs.sort() + files.sort() + return dirs, files + except PermissionError: + return [], [] + +def custom_file_picker(start_path): + current_path = os.path.abspath(start_path) + if not os.path.isdir(current_path): + current_path = os.getcwd() + + while True: + dirs, files = get_directory_contents(current_path) + + menu_items = [] + + # Parent directory option + if current_path != "/": + menu_items.extend(["..", "Parent Directory"]) + + # Directories + for d in dirs: + menu_items.extend([d + "/", ""]) + + # Files + for f in files: + menu_items.extend([f, ""]) + + if not menu_items: + menu_items.extend([".", "Empty Directory"]) + + # Title shows current path truncated if needed + pretty_path = current_path + if len(pretty_path) > 50: + pretty_path = "..." + pretty_path[-47:] + + selection, code = run_dialog([ + "--title", f"Select GGUF File", + "--backtitle", f"Current: {pretty_path}", + "--menu", "Navigate directories and select a .gguf file:", "20", "70", "12", + *menu_items + ]) + + if code != 0: # Cancel/Escape + return None + + clean_selection = selection.strip() + + if clean_selection == "..": + current_path = os.path.dirname(current_path) + elif clean_selection.endswith("/"): + # Enter directory + dir_name = clean_selection[:-1] # Remove slash + current_path = os.path.join(current_path, dir_name) + elif clean_selection == ".": + pass # Stay here + else: + # File selected + return os.path.join(current_path, clean_selection) + +# --- Main Logic --- + +class AppState: + def __init__(self): + self.model_path = "" + self.toolbox = DEFAULT_TOOLBOX + self.mode = DEFAULT_MODE + # List of [ip, enabled] + self.hosts = [list(h) for h in DEFAULT_HOSTS] + self.context_size = None # None means default (do not pass -c) + + @property + def active_hosts(self): + return [h[0] for h in self.hosts if h[1]] + +def select_model(state): + start_path = state.model_path if state.model_path else os.getcwd() + if os.path.isfile(start_path): + start_path = os.path.dirname(start_path) + + selection = custom_file_picker(start_path) + if selection: + state.model_path = selection + +def select_toolbox(state): + menu_items = [] + for key in TOOLBOX_IMAGES.keys(): + menu_items.extend([key, TOOLBOX_IMAGES[key]]) + + selection, code = run_dialog([ + "--title", "Select Toolbox", + "--menu", "Choose the container environment:", "15", "60", "8", + *menu_items + ]) + if code == 0 and selection: + state.toolbox = selection + +def select_mode(state): + menu_items = [] + for m in MODES: + menu_items.extend([m, ""]) + + selection, code = run_dialog([ + "--title", "Select Execution Mode", + "--menu", "Choose how to run the model:", "12", "50", "5", + *menu_items + ]) + if code == 0 and selection: + state.mode = selection + +def select_context(state): + current = str(state.context_size) if state.context_size else "" + selection, code = run_dialog([ + "--title", "Context Size", + "--inputbox", "Enter context size (e.g. 4096, 8192).\nLeave empty for model default:", "10", "60", + current + ]) + if code == 0: + val = selection.strip() + if val.isdigit(): + state.context_size = int(val) + else: + state.context_size = None + +def add_server(state): + selection, code = run_dialog([ + "--title", "Add Server", + "--inputbox", "Enter new server IP address:", "10", "50" + ]) + if code == 0: + ip = selection.strip() + if ip: + # Default to enabled + state.hosts.append([ip, True]) + +def remove_server(state): + items = [] + for i, (ip, enabled) in enumerate(state.hosts): + items.extend([str(i), ip]) + + if not items: + show_msg("Info", "No servers to remove.") + return + + selection, code = run_dialog([ + "--title", "Remove Server", + "--menu", "Select server to remove:", "15", "50", "5", + *items + ]) + + if code == 0 and selection: + idx = int(selection) + if 0 <= idx < len(state.hosts): + del state.hosts[idx] + +def edit_server(state): + items = [] + for i, (ip, enabled) in enumerate(state.hosts): + items.extend([str(i), ip]) + + if not items: + show_msg("Info", "No servers to edit.") + return + + selection, code = run_dialog([ + "--title", "Edit Server", + "--menu", "Select server to edit:", "15", "50", "5", + *items + ]) + + if code == 0 and selection: + idx = int(selection) + if 0 <= idx < len(state.hosts): + current_ip = state.hosts[idx][0] + new_ip, code2 = run_dialog([ + "--title", "Edit Server IP", + "--inputbox", "Enter new IP address:", "10", "50", + current_ip + ]) + if code2 == 0: + clean_ip = new_ip.strip() + if clean_ip: + state.hosts[idx][0] = clean_ip + +def toggle_servers(state): + # checklist: item tag, item string, status (on/off) + items = [] + for i, (ip, enabled) in enumerate(state.hosts): + status = "on" if enabled else "off" + items.extend([str(i), ip, status]) + + if not items: + show_msg("Info", "No servers to configure. Add some first.") + return + + selection_str, code = run_dialog([ + "--title", "Toggle Active Servers", + "--checklist", "Select active servers (Space to toggle):", "15", "50", "5", + *items + ]) + + if code == 0: + # Reset all to False first + for h in state.hosts: + h[1] = False + + if selection_str: + # e.g. "0 2" + indices = [int(x.strip('"')) for x in selection_str.split()] + for idx in indices: + if 0 <= idx < len(state.hosts): + state.hosts[idx][1] = True + +def configure_servers(state): + while True: + menu = [ + "1", "Toggle Active Servers", + "2", "Add Server", + "3", "Remove Server", + "4", "Edit Server", + "5", "Back" + ] + + selection, code = run_dialog([ + "--title", "Manage Remote Servers", + "--menu", "Choose an action:", "15", "50", "5", + *menu + ]) + + if code != 0 or selection == "5": + break + + if selection == "1": + toggle_servers(state) + elif selection == "2": + add_server(state) + elif selection == "3": + remove_server(state) + elif selection == "4": + edit_server(state) + +def run_distributed(state): + if not state.model_path or not os.path.exists(state.model_path): + show_msg("Error", f"Model file not found:\n{state.model_path}") + return + + if not state.active_hosts: + show_msg("Error", "No remote servers selected.") + return + + image = TOOLBOX_IMAGES[state.toolbox] + active_ips = state.active_hosts + + # Clear screen for execution output + subprocess.run(["clear"]) + print(f"=== Starting Distributed Run ===") + print(f"Model: {state.model_path}") + print(f"Toolbox: {state.toolbox} ({image})") + print(f"Mode: {state.mode}") + print(f"Context: {state.context_size if state.context_size else 'Default'}") + print(f"Hosts: {active_ips}") + print("--------------------------------") + + remote_pids = [] + + def cleanup(): + print("\nCleaning up...") + for i, ip in enumerate(active_ips): + if i < len(remote_pids): + pid = remote_pids[i] + if pid: + print(f"Killing remote RPC on {ip} (PID: {pid})...") + subprocess.run( + ["ssh", "-p", REMOTE_PORT, ip, f"kill -9 {pid} 2>/dev/null || true; pkill -9 -f rpc-server || true"], + stderr=subprocess.DEVNULL + ) + + # Register signal handler for cleanup + def signal_handler(sig, frame): + cleanup() + sys.exit(0) + signal.signal(signal.SIGINT, signal_handler) + + try: + rpc_arg_parts = [] + + # 1. Start Remote RPC Servers + for ip in active_ips: + print(f"-> Starting RPC server on {ip}...") + + # Using bash heredoc via ssh to start background process and print PID + # We assume 'toolbox' command exists on remote + cmd_str = f""" + set -euo pipefail + pkill -9 -f rpc-server || true + nohup toolbox run -c {image} -- rpc-server -H 0.0.0.0 -p {RPC_PORT} -c > /tmp/rpc-server-{ip}.log 2>&1 < /dev/null & + echo $! + """ + + res = subprocess.run( + ["ssh", "-p", REMOTE_PORT, ip, "bash -s"], + input=cmd_str, text=True, capture_output=True + ) + + if res.returncode != 0: + print(f"[ERROR] SSH failed for {ip}: {res.stderr}") + cleanup() + return + + pid = res.stdout.strip() + # Basic validation + if not pid.isdigit(): + lines = pid.splitlines() + if lines and lines[-1].isdigit(): + pid = lines[-1] + else: + print(f"[ERROR] Invalid PID returned from {ip}: {pid}") + cleanup() + return + + remote_pids.append(pid) + print(f" PID: {pid}") + + # Wait for port check + print(f" Waiting for port {RPC_PORT}...", end="", flush=True) + ready = False + for _ in range(30): + import socket + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.settimeout(1) + try: + s.connect((ip, int(RPC_PORT))) + s.close() + ready = True + print(" OK.") + break + except: + time.sleep(1) + + if not ready: + print(" TIMEOUT.") + print(f"[ERROR] Failed to connect to {ip}:{RPC_PORT}") + cleanup() + return + + rpc_arg_parts.append(f"{ip}:{RPC_PORT}") + + rpc_arg = ",".join(rpc_arg_parts) + print(f"All servers ready. RPC Arg: {rpc_arg}") + print(f"Starting Local {state.mode}...") + print("--------------------------------") + + # 2. Run Local Executable + # Base arguments for all modes + base_args = [ + "toolbox", "run", "-c", image, "--", + state.mode, + "-m", state.model_path, + "--rpc", rpc_arg + ] + + if state.mode == "llama-server": + # Llama Server specific + extra_args = [ + "--no-mmap", + "-fa", "1", + "--host", "0.0.0.0", + "--port", LOCAL_HOST_PORT + ] + if state.context_size: + extra_args.extend(["-c", str(state.context_size)]) + + elif state.mode == "llama-cli": + # Llama CLI specific (interactive or basic run) + # User requested -mmp 0 and -fa 1 + extra_args = [ + "--no-mmap", + "-fa", "1", + "-cnv", # Conversation mode seems appropriate for CLI + "-p", "You are a helpful assistant." + ] + if state.context_size: + extra_args.extend(["-c", str(state.context_size)]) + + elif state.mode == "llama-bench": + # Llama Bench specific + # User requested -mmp 0 and -fa 1 (Note: llama-bench uses different arg names sometimes?) + # llama-bench: -mmp (mmap) + extra_args = [ + "-mmp", "0", + "-fa", "1" + ] + # bench usually controls context via other flags, user didn't ask for it here. + else: + extra_args = [] + + local_cmd = base_args + extra_args + + print(f"CMD: {' '.join(local_cmd)}") + + proc = subprocess.Popen(local_cmd) + proc.wait() + + except Exception as e: + print(f"\n[EXCEPTION] {e}") + finally: + cleanup() + + input("\nRun complete. Press Enter to return to menu...") + + +def main_menu(): + state = AppState() + + while True: + model_display = Path(state.model_path).name if state.model_path else "(None)" + servers_display = f"{len(state.active_hosts)} Active" + context_display = str(state.context_size) if state.context_size else "Default" + + menu = [ + "--clear", "--backtitle", "AMD Strix Halo - Distributed Llama", + "--title", "Main Menu", + "--menu", "Select an option to configure or run:", "20", "60", "7", + "1", f"Model: {model_display}", + "2", f"Toolbox: {state.toolbox}", + "3", f"Servers: {servers_display}", + "4", f"Mode: {state.mode}", + "5", f"Context: {context_display}", + "6", "RUN DISTRIBUTED SERVER", + "7", "Exit" + ] + + choice, code = run_dialog(menu) + + if code != 0: # Cancel/Esc + break + + if choice == "1": + select_model(state) + elif choice == "2": + select_toolbox(state) + elif choice == "3": + configure_servers(state) + elif choice == "4": + select_mode(state) + elif choice == "5": + select_context(state) + elif choice == "6": + run_distributed(state) + elif choice == "7": + break + + subprocess.run(["clear"]) + exit(0) + +if __name__ == "__main__": + check_dependencies() + try: + main_menu() + except KeyboardInterrupt: + subprocess.run(["clear"]) + sys.exit(0)