Updates to Dockerfile buil docs

This commit is contained in:
Donato Capitella
2025-08-03 13:56:16 +01:00
parent 5d9e6f424e
commit 6c66edf0b7
9 changed files with 6 additions and 4 deletions
+162
View File
@@ -0,0 +1,162 @@
#!/usr/bin/env python3
import sys
import os
import re
import struct
import argparse
import math
from typing import Dict, Any, List
# GGUF constants
GGUF_MAGIC = 0x46554747
GGUF_VALUE_TYPE = {
0: "UINT8", 1: "INT8", 2: "UINT16", 3: "INT16", 4: "UINT32",
5: "INT32", 6: "FLOAT32", 7: "BOOL", 8: "STRING", 9: "ARRAY",
}
class GGUFMetadataReader:
"""A minimal reader to get only the necessary KV metadata for cache calculation."""
def __init__(self, path: str):
self.path = path
self.metadata: Dict[str, Any] = {}
def read(self):
with open(self.path, "rb") as f:
self.f = f
magic, _, _, metadata_kv_count = struct.unpack("<IIQQ", self.f.read(24))
if magic != GGUF_MAGIC: raise ValueError("Invalid GGUF magic number")
self._read_metadata(metadata_kv_count)
return self
def _read_string(self) -> str:
(length,) = struct.unpack("<Q", self.f.read(8))
return self.f.read(length).decode("utf-8", errors="replace")
def _read_value(self, value_type_idx: int):
value_type = GGUF_VALUE_TYPE.get(value_type_idx)
if not value_type: raise ValueError(f"Unknown GGUF value type: {value_type_idx}")
if value_type == "STRING": return self._read_string()
if value_type == "UINT32": return struct.unpack("<I", self.f.read(4))[0]
if value_type == "INT32": return struct.unpack("<i", self.f.read(4))[0]
self._skip_value(value_type_idx)
def _skip_value(self, value_type_idx: int):
value_type = GGUF_VALUE_TYPE.get(value_type_idx)
if not value_type: return
if value_type in ("UINT8", "INT8", "BOOL"): self.f.seek(1, 1)
elif value_type in ("UINT16", "INT16"): self.f.seek(2, 1)
elif value_type in ("UINT32", "INT32", "FLOAT32"): self.f.seek(4, 1)
elif value_type == "STRING":
(length,) = struct.unpack("<Q", self.f.read(8))
self.f.seek(length, 1)
elif value_type == "ARRAY":
(array_type_idx, count) = struct.unpack("<IQ", self.f.read(12))
type_map = {0:1, 1:1, 2:2, 3:2, 4:4, 5:4, 6:4, 7:1, 10:8, 11:8, 12:8}
element_size = type_map.get(array_type_idx)
if element_size: self.f.seek(count * element_size, 1)
else:
for _ in range(count): self._skip_value(8)
def _read_metadata(self, count: int):
keys_to_read = {"general.architecture", "general.name"}
arch_specific_keys_added = False
for _ in range(count):
key = self._read_string()
(value_type_idx,) = struct.unpack("<I", self.f.read(4))
if not arch_specific_keys_added and "general.architecture" in self.metadata:
prefix = self.metadata["general.architecture"]
keys_to_read.update({
f"{prefix}.block_count", f"{prefix}.context_length",
f"{prefix}.attention.head_count_kv", f"{prefix}.attention.key_length",
f"{prefix}.attention.value_length", f"{prefix}.attention.sliding_window_size"
})
arch_specific_keys_added = True
if key in keys_to_read:
self.metadata[key] = self._read_value(value_type_idx)
else:
self._skip_value(value_type_idx)
def get_total_model_size_from_disk(gguf_file_path: str) -> int:
"""Calculates the total model size by finding all parts on disk."""
match = re.search(r'-(\d{5})-of-(\d{5})\.gguf$', gguf_file_path, re.IGNORECASE)
if not match:
return os.path.getsize(gguf_file_path)
base_path = gguf_file_path[:match.start()]
total_parts_str = match.group(2)
total_parts = int(total_parts_str)
total_size, found_parts = 0, 0
for i in range(1, total_parts + 1):
part_file_name = f"{base_path}-{i:05d}-of-{total_parts_str}.gguf"
if os.path.exists(part_file_name):
total_size += os.path.getsize(part_file_name)
found_parts += 1
if found_parts != total_parts:
print(f"WARNING: Expected {total_parts} parts, found {found_parts}. Size calculation may be incomplete.", file=sys.stderr)
return total_size
def format_mem(size_bytes):
mib = size_bytes / (1024 * 1024)
if mib < 1024: return f"{mib:8.2f} MiB"
return f"{mib / 1024:8.2f} GiB"
def run_estimator(gguf_file: str, context_sizes: List[int], overhead_gib: float):
try:
reader = GGUFMetadataReader(gguf_file).read()
metadata = reader.metadata
prefix = metadata.get("general.architecture")
if not prefix: raise KeyError("Could not read 'general.architecture' from model metadata.")
model_size_bytes = get_total_model_size_from_disk(gguf_file)
overhead_bytes = int(overhead_gib * 1024**3)
n_layers = metadata[f"{prefix}.block_count"]
n_head_kv = metadata[f"{prefix}.attention.head_count_kv"]
training_context = metadata.get(f"{prefix}.context_length", 0)
n_embd_head_k = metadata[f"{prefix}.attention.key_length"]
n_embd_head_v = metadata[f"{prefix}.attention.value_length"]
swa_window_size = metadata.get(f"{prefix}.attention.sliding_window_size", 0)
is_scout_model = "scout" in metadata.get("general.name", "").lower()
if is_scout_model and swa_window_size == 0: n_layers_swa, n_layers_full, swa_window_size = 36, 12, 8192
elif swa_window_size > 0: n_layers_swa, n_layers_full = n_layers, 0
else: n_layers_swa, n_layers_full = 0, n_layers
print(f"\n--- Model '{metadata.get('general.name', 'N/A')}' ---")
if training_context > 0: print(f"Max Context: {training_context:,} tokens")
print(f"Model Size: {format_mem(model_size_bytes).strip()} (from file size)")
print(f"Incl. Overhead: {overhead_gib:.2f} GiB (for compute buffer, etc. adjustable via --overhead)")
if training_context > 0:
context_sizes = sorted(list(set([c for c in context_sizes if c <= training_context] + [c for c in [training_context] if c not in context_sizes])))
else: context_sizes = sorted(context_sizes)
bytes_per_token_per_layer = n_head_kv * (n_embd_head_k + n_embd_head_v) * 2
print("\n--- Memory Footprint Estimation ---")
print(f"{'Context Size':>15s} | {'Context Memory':>15s} | {'Est. Total VRAM':>15s}")
print("-" * 51)
for n_ctx in context_sizes:
mem_full = n_ctx * n_layers_full * bytes_per_token_per_layer
mem_swa = min(n_ctx, swa_window_size) * n_layers_swa * bytes_per_token_per_layer
kv_cache_bytes = mem_full + mem_swa
total_bytes = model_size_bytes + kv_cache_bytes + overhead_bytes
print(f"{n_ctx:>15,} | {format_mem(kv_cache_bytes):>15s} | {format_mem(total_bytes):>15s}")
except (FileNotFoundError, ValueError, struct.error, NotImplementedError, KeyError) as e:
print(f"\nError: {e}", file=sys.stderr)
sys.exit(1)
def main():
parser = argparse.ArgumentParser(
description="Calculate VRAM requirements for a GGUF model, including a configurable overhead for compute buffers.",
formatter_class=argparse.RawTextHelpFormatter
)
parser.add_argument("gguf_file", help="Path to the GGUF model file (any part of a multi-part model).")
parser.add_argument("-c", "--contexts", nargs='+', type=int, default=[4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576], help="Space-separated list of context sizes to calculate.")
parser.add_argument("--overhead", type=float, default=2.0, help="Estimated overhead in GiB for compute buffers, drivers, etc. (default: 2.0)")
args = parser.parse_args()
run_estimator(args.gguf_file, args.contexts, args.overhead)
if __name__ == "__main__":
main()