From 5f3608181b9544b3ba9c685432799521404dbada Mon Sep 17 00:00:00 2001 From: supermario_leo Date: Fri, 19 Jun 2026 22:44:26 +0800 Subject: [PATCH] feat(vram): model sliding-window attention in KV cache estimation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit estimate_kv_cache scaled the KV cache linearly with the full requested context for every model, so it over-counted VRAM for sliding-window-attention (SWA) models whose local-attention layers only cache the last `window` tokens. At long context this inflated the estimate enough to make the ranker demote models that actually fit (e.g. Gemma-3-27B at 128K: 12.7 GB KV estimated vs ~2.2 GB real). Add architecture-gated SWA modeling: - ModelInfo gains sliding_window and sliding_window_global_ratio. - The fetcher populates them only for architectures whose mainline runtimes actually honor interleaved SWA (Gemma-2/3, gpt-oss, Cohere2), read from HF config sliding_window/sliding_window_pattern, then GGUF metadata architecture, then a boundary-matched, conflict-guarded model-id fallback for config-less GGUF repos. - estimate_kv_cache blends global and windowed layers into an effective context length: global_ratio*ctx + (1-global_ratio)*min(ctx, window). Models outside the allowlist keep the full-context KV figure — including Mistral-7B-v0.1, whose config advertises a 4096 window that mainline runtimes ignore. The reduction can therefore only ever lower an estimate where the saving is real, and stays conservative everywhere else. sliding_window=None reproduces the previous formula exactly. Addresses the sliding-window-attention item in #25. Tests cover dense, pure-SWA, hybrid, the conservative no-window default, GGUF-metadata and id-hint resolution, merge/boundary false-positive guards, and cache round-trip. --- src/whichllm/engine/vram.py | 31 +++++- src/whichllm/models/fetcher.py | 153 +++++++++++++++++++++++++++ src/whichllm/models/types.py | 8 ++ tests/test_fetcher.py | 183 +++++++++++++++++++++++++++++++++ tests/test_vram.py | 86 ++++++++++++++++ 5 files changed, 460 insertions(+), 1 deletion(-) diff --git a/src/whichllm/engine/vram.py b/src/whichllm/engine/vram.py index 34cf7df..d5427be 100644 --- a/src/whichllm/engine/vram.py +++ b/src/whichllm/engine/vram.py @@ -19,12 +19,40 @@ _MOE_ATTENTION_PARAM_MULTIPLIER = 4.0 +def _effective_kv_context(model: ModelInfo, context_length: int) -> float: + """Context length that actually contributes to the KV cache. + + For sliding-window-attention (SWA) models, local-attention layers only keep + the last ``sliding_window`` tokens, so their KV footprint plateaus once the + request exceeds the window. Hybrid models interleave a fraction of global + (full-context) layers; that fraction is ``sliding_window_global_ratio``. + + Effective context blends the two layer types:: + + global_ratio * ctx + (1 - global_ratio) * min(ctx, window) + + This is only applied for architectures whose mainline runtimes honor SWA + (the fetcher leaves ``sliding_window`` ``None`` otherwise), and it can only + ever lower the estimate — so a model that does not advertise an honored + window keeps the full-context KV figure and stays conservative. + """ + window = model.sliding_window + ratio = model.sliding_window_global_ratio + if not window or window <= 0 or ratio is None: + return float(context_length) + ratio = min(max(ratio, 0.0), 1.0) + windowed = min(context_length, window) + return ratio * context_length + (1.0 - ratio) * windowed + + def estimate_kv_cache(model: ModelInfo, context_length: int) -> int: """Estimate KV cache size in bytes for a given context length. Dense models: KV ≈ 3 MB × params_b × ctx_k (FP16 K+V across all layers). MoE models: scale from active params × an empirical multiplier because attention shares across experts. + Sliding-window models cap local-layer KV at the window size (see + :func:`_effective_kv_context`). """ if model.is_moe and model.parameter_count_active: # Active-params × MoE multiplier gives a reasonable proxy for the @@ -34,7 +62,8 @@ def estimate_kv_cache(model: ModelInfo, context_length: int) -> int: else: params_b = model.parameter_count / 1e9 - ctx_k = context_length / 1024 + effective_ctx = _effective_kv_context(model, context_length) + ctx_k = effective_ctx / 1024 kv_bytes = int(params_b * ctx_k * _KV_BYTES_PER_BPARAM_PER_KCTX) return max(kv_bytes, 0) diff --git a/src/whichllm/models/fetcher.py b/src/whichllm/models/fetcher.py index 7586ac1..a81845c 100644 --- a/src/whichllm/models/fetcher.py +++ b/src/whichllm/models/fetcher.py @@ -175,6 +175,148 @@ def _resolve_moe_active_params( return None +# Sliding-window-attention (SWA) registry. We only model SWA KV-cache savings +# for architectures whose mainline runtimes actually honor interleaved SWA +# (llama.cpp's ISWA path, MLX). Each entry is (default_window_tokens, +# global_layer_ratio) where the ratio is the fraction of layers that use full +# (global) attention. Models outside this allowlist keep full-context KV so +# estimates stay conservative — notably Mistral-7B-v0.1, whose config declares +# sliding_window=4096 but whose window is ignored by mainline runtimes. +# - gemma2: alternating local/global -> 1/2 global +# - gemma3: 5 local : 1 global (sliding_window_pattern=6) -> 1/6 global +# - gpt_oss: alternating sliding/full -> 1/2 global +# - cohere2: 3 local : 1 global (sliding_window_pattern=4) -> 1/4 global +_SWA_ARCH_DEFAULTS: dict[str, tuple[int, float]] = { + "gemma2": (4096, 0.5), + "gemma3": (1024, 1.0 / 6.0), + "gpt_oss": (128, 0.5), + "cohere2": (4096, 0.25), +} + +# Map the many spellings an arch string can take (HF model_type, the +# ForCausalLM/ForConditionalGeneration class prefix, and GGUF metadata) onto a +# canonical key in _SWA_ARCH_DEFAULTS. +_SWA_ARCH_ALIASES: dict[str, str] = { + "gemma2": "gemma2", + "gemma2_text": "gemma2", + "gemma3": "gemma3", + "gemma3_text": "gemma3", + "gpt_oss": "gpt_oss", + "gptoss": "gpt_oss", + "cohere2": "cohere2", +} + +# Last-resort fallback keyed by a model-name token, for GGUF-only repos that +# expose neither an HF config nor GGUF architecture metadata. Matched only as a +# delimited token in the model name, and only when no *other* base-architecture +# family is named (to avoid mislabeling merges/finetunes). Kept deliberately +# small and precise — a false positive here would under-count VRAM. +_SWA_ID_HINTS: tuple[tuple[str, str], ...] = ( + ("gemma-3", "gemma3"), + ("gemma-2", "gemma2"), + ("gpt-oss", "gpt_oss"), + ("command-r7b", "cohere2"), +) + +# Other base families that, if named in a model id, make an id-hint match +# ambiguous (e.g. a "gemma-3 distilled from llama" merge) — bail conservatively. +_CONFLICTING_ARCH_TOKENS: tuple[str, ...] = ( + "llama", + "qwen", + "mistral", + "mixtral", + "phi", + "deepseek", + "yi", + "falcon", + "internlm", + "baichuan", + "glm", +) + + +def _swa_key_from_arch(arch: str | None) -> str | None: + """Resolve an arch string (model_type / class / gguf metadata) to a key.""" + if not arch: + return None + arch = arch.lower() + if arch in _SWA_ARCH_ALIASES: + return _SWA_ARCH_ALIASES[arch] + stripped = arch.replace("forcausallm", "").replace("forconditionalgeneration", "") + if stripped in _SWA_ARCH_ALIASES: + return _SWA_ARCH_ALIASES[stripped] + return None + + +def _swa_key_from_id(model_id: str) -> str | None: + """Boundary-aware, conflict-guarded model-id fallback (see _SWA_ID_HINTS).""" + name = model_id.split("/")[-1].lower() + for hint, key in _SWA_ID_HINTS: + if re.search(rf"(? gemma + if any(tok in name for tok in _CONFLICTING_ARCH_TOKENS if tok != family): + return None + return key + return None + + +def _swa_arch_key(config: dict, model_id: str, gguf_arch: str | None) -> str | None: + """Identify the SWA architecture key for a model, or None if not honored. + + Prefers authoritative sources — the raw HF config ``model_type``/ + ``architectures`` (not the normalized architecture string, which collapses + gemma2 and gemma3 into "gemma") and the GGUF metadata architecture — before + a narrow, conflict-guarded model-id fallback for config-less GGUF repos. + """ + model_type = config.get("model_type") + key = _swa_key_from_arch(model_type if isinstance(model_type, str) else None) + if key: + return key + + arch_list = config.get("architectures") or [] + if arch_list and isinstance(arch_list[0], str): + key = _swa_key_from_arch(arch_list[0]) + if key: + return key + + key = _swa_key_from_arch(gguf_arch) + if key: + return key + + return _swa_key_from_id(model_id) + + +def _resolve_sliding_window( + config: dict, model_id: str, gguf_arch: str | None = None +) -> tuple[int | None, float | None]: + """Resolve (sliding_window, global_ratio) for honored SWA architectures. + + Returns (None, None) for every model outside the allowlist so the KV + estimate stays at full context (conservative). + """ + # Respect an explicit opt-out before doing any work. + if config.get("use_sliding_window") is False: + return None, None + + key = _swa_arch_key(config, model_id, gguf_arch) + if key is None: + return None, None + + default_window, default_ratio = _SWA_ARCH_DEFAULTS[key] + + window = config.get("sliding_window") + if not isinstance(window, int) or window <= 0: + window = default_window + + pattern = config.get("sliding_window_pattern") + if isinstance(pattern, int) and pattern > 0: + global_ratio = 1.0 / pattern + else: + global_ratio = default_ratio + + return window, global_ratio + + def _normalize_param_count( extracted: int, model_id: str, @@ -549,6 +691,11 @@ def _parse_model(data: dict) -> ModelInfo | None: if not context_length and isinstance(gguf_meta, dict): context_length = gguf_meta.get("context_length") + gguf_arch = gguf_meta.get("architecture") if isinstance(gguf_meta, dict) else None + sliding_window, swa_global_ratio = _resolve_sliding_window( + config, model_id, gguf_arch + ) + benchmark_scores: dict[str, float] = {} eval_score = _extract_hf_eval_score(data) if eval_score is not None: @@ -570,6 +717,8 @@ def _parse_model(data: dict) -> ModelInfo | None: gguf_variants=gguf_variants, benchmark_scores=benchmark_scores, base_model=base_model, + sliding_window=sliding_window, + sliding_window_global_ratio=swa_global_ratio, ) @@ -877,6 +1026,8 @@ def models_to_dicts(models: list[ModelInfo]) -> list[dict]: ], "benchmark_scores": m.benchmark_scores, "base_model": m.base_model, + "sliding_window": m.sliding_window, + "sliding_window_global_ratio": m.sliding_window_global_ratio, } ) return result @@ -925,6 +1076,8 @@ def dicts_to_models(data: list[dict]) -> list[ModelInfo]: ], benchmark_scores=d.get("benchmark_scores", {}), base_model=base_model, + sliding_window=d.get("sliding_window"), + sliding_window_global_ratio=d.get("sliding_window_global_ratio"), ) ) return models diff --git a/src/whichllm/models/types.py b/src/whichllm/models/types.py index d8e7479..56abcba 100644 --- a/src/whichllm/models/types.py +++ b/src/whichllm/models/types.py @@ -27,6 +27,14 @@ class ModelInfo: gguf_variants: list[GGUFVariant] = field(default_factory=list) benchmark_scores: dict[str, float] = field(default_factory=dict) base_model: str | None = None # cardData.base_model + # Sliding-window-attention KV-cache modeling. Only populated for + # architectures whose mainline runtimes actually honor interleaved SWA + # (Gemma-2/3, gpt-oss, Cohere2); left None otherwise so VRAM estimates + # stay conservative. sliding_window is the local-attention window in + # tokens; sliding_window_global_ratio is the fraction of layers that use + # full (global) attention (0.0 = pure SWA, 1.0 = fully dense). + sliding_window: int | None = None + sliding_window_global_ratio: float | None = None @dataclass diff --git a/tests/test_fetcher.py b/tests/test_fetcher.py index 4df084b..3abb23f 100644 --- a/tests/test_fetcher.py +++ b/tests/test_fetcher.py @@ -398,3 +398,186 @@ def test_deepseek_v4_flash_uses_model_card_counts_over_hf_tensor_metadata(): assert parsed is not None assert parsed.parameter_count == 284_000_000_000 assert parsed.parameter_count_active == 13_000_000_000 + + +def test_parse_model_resolves_gemma3_sliding_window(): + parsed = _parse_model( + { + "id": "google/gemma-3-27b-it", + "config": { + "architectures": ["Gemma3ForConditionalGeneration"], + "model_type": "gemma3", + "sliding_window": 1024, + "sliding_window_pattern": 6, + "max_position_embeddings": 131072, + }, + "safetensors": {"total": 27_000_000_000}, + "siblings": [], + "cardData": {}, + } + ) + assert parsed is not None + assert parsed.sliding_window == 1024 + assert parsed.sliding_window_global_ratio == 1.0 / 6.0 + + +def test_parse_model_resolves_gpt_oss_sliding_window(): + parsed = _parse_model( + { + "id": "openai/gpt-oss-20b", + "config": { + "architectures": ["GptOssForCausalLM"], + "model_type": "gpt_oss", + "sliding_window": 128, + "max_position_embeddings": 131072, + }, + "safetensors": {"total": 20_000_000_000}, + "siblings": [], + "cardData": {}, + } + ) + assert parsed is not None + assert parsed.sliding_window == 128 + assert parsed.sliding_window_global_ratio == 0.5 + + +def test_parse_model_does_not_honor_mistral_sliding_window(): + """Mistral declares sliding_window in config but runtimes ignore it.""" + parsed = _parse_model( + { + "id": "mistralai/Mistral-7B-v0.1", + "config": { + "architectures": ["MistralForCausalLM"], + "model_type": "mistral", + "sliding_window": 4096, + "max_position_embeddings": 32768, + }, + "safetensors": {"total": 7_000_000_000}, + "siblings": [], + "cardData": {}, + } + ) + assert parsed is not None + assert parsed.sliding_window is None + assert parsed.sliding_window_global_ratio is None + + +def test_parse_model_respects_disabled_sliding_window(): + parsed = _parse_model( + { + "id": "google/gemma-3-4b-it", + "config": { + "architectures": ["Gemma3ForConditionalGeneration"], + "model_type": "gemma3", + "sliding_window": 1024, + "use_sliding_window": False, + }, + "safetensors": {"total": 4_000_000_000}, + "siblings": [], + "cardData": {}, + } + ) + assert parsed is not None + assert parsed.sliding_window is None + assert parsed.sliding_window_global_ratio is None + + +def test_parse_model_does_not_honor_gemma1_window(): + """Gemma-1 has no sliding window; the gemma2/gemma3 keys must not match it.""" + parsed = _parse_model( + { + "id": "google/gemma-7b", + "config": { + "architectures": ["GemmaForCausalLM"], + "model_type": "gemma", + }, + "safetensors": {"total": 7_000_000_000}, + "siblings": [], + "cardData": {}, + } + ) + assert parsed is not None + assert parsed.sliding_window is None + + +def test_parse_model_sliding_window_from_gguf_metadata(): + """GGUF-only repos: trust the authoritative GGUF architecture metadata.""" + parsed = _parse_model( + { + "id": "unsloth/gemma-3-12b-it-GGUF", + "config": {}, + "gguf": {"architecture": "gemma3"}, + "safetensors": {"total": 12_000_000_000}, + "siblings": [], + "cardData": {}, + } + ) + assert parsed is not None + assert parsed.sliding_window == 1024 + assert parsed.sliding_window_global_ratio == 1.0 / 6.0 + + +def test_parse_model_sliding_window_from_id_hint_when_no_metadata(): + """With neither config nor GGUF arch, fall back to a boundary-matched hint.""" + parsed = _parse_model( + { + "id": "TheBloke/gemma-2-9b-it-GGUF", + "config": {}, + "gguf": {}, + "safetensors": {"total": 9_000_000_000}, + "siblings": [], + "cardData": {}, + } + ) + assert parsed is not None + assert parsed.sliding_window == 4096 + assert parsed.sliding_window_global_ratio == 0.5 + + +def test_parse_model_id_hint_ignores_merge_with_other_base(): + """A merge/finetune that names another base arch must NOT get a fake window.""" + parsed = _parse_model( + { + "id": "someorg/llama3-gemma-2-distill-merge", + "config": {}, + "gguf": {}, + "safetensors": {"total": 8_000_000_000}, + "siblings": [], + "cardData": {}, + } + ) + assert parsed is not None + assert parsed.sliding_window is None + assert parsed.sliding_window_global_ratio is None + + +def test_parse_model_id_hint_requires_token_boundary(): + """A substring buried inside another token must not trigger the hint.""" + parsed = _parse_model( + { + "id": "someorg/notagemma-2x-model", + "config": {}, + "gguf": {}, + "safetensors": {"total": 8_000_000_000}, + "siblings": [], + "cardData": {}, + } + ) + assert parsed is not None + assert parsed.sliding_window is None + + +def test_models_cache_roundtrip_keeps_sliding_window(): + models = [ + ModelInfo( + id="google/gemma-3-27b-it", + family_id="gemma-3-27b", + name="gemma-3-27b-it", + parameter_count=27_000_000_000, + sliding_window=1024, + sliding_window_global_ratio=1.0 / 6.0, + ) + ] + restored = dicts_to_models(models_to_dicts(models)) + assert restored[0].sliding_window == 1024 + assert restored[0].sliding_window_global_ratio == 1.0 / 6.0 diff --git a/tests/test_vram.py b/tests/test_vram.py index 7e256e4..74b54fd 100644 --- a/tests/test_vram.py +++ b/tests/test_vram.py @@ -60,3 +60,89 @@ def test_estimate_vram_small_model(): # Should be reasonable for a tiny model assert vram > 300_000_000 assert vram < 3_000_000_000 + + +def test_kv_cache_unchanged_when_no_sliding_window(): + """Models without an honored sliding window keep the full-context KV figure. + + This is the conservative-default guarantee: the SWA change must be a no-op + for every model that does not advertise an honored window. Expected values + are pinned literals (the current 3.5 MB/B/Kctx coefficient) so the test + fails if either the formula or the coefficient drifts. + """ + dense = _make_model(7_000_000_000) + expected = { + 4096: 102_760_448, + 32768: 822_083_584, + 131072: 3_288_334_336, + } + for ctx, want in expected.items(): + assert estimate_kv_cache(dense, ctx) == want + + +def test_kv_cache_mistral_window_in_config_is_ignored(): + """A declared window is NOT honored unless the architecture is allowlisted. + + Mistral-7B-v0.1 ships sliding_window=4096 in its config but mainline + runtimes ignore it, so whichllm must stay at full-context KV. We model this + by leaving sliding_window=None on the model; the estimate must equal dense. + """ + mistral = _make_model(7_000_000_000, architecture="mistral") + dense = _make_model(7_000_000_000) + for ctx in (4096, 131072): + assert estimate_kv_cache(mistral, ctx) == estimate_kv_cache(dense, ctx) + + +def test_kv_cache_pure_swa_plateaus_beyond_window(): + """Pure sliding-window models (global_ratio=0) plateau at the window size.""" + swa = _make_model( + 7_000_000_000, sliding_window=4096, sliding_window_global_ratio=0.0 + ) + at_window = estimate_kv_cache(swa, 4096) + far_beyond = estimate_kv_cache(swa, 131072) + # KV is flat once context exceeds the window. + assert far_beyond == at_window + # And it is far below the dense estimate at the same long context. + dense = estimate_kv_cache(_make_model(7_000_000_000), 131072) + assert far_beyond < dense / 10 + + +def test_kv_cache_hybrid_grows_slower_than_dense(): + """Hybrid SWA models grow with context but much slower than dense.""" + # Gemma-3-like: 1/6 global layers, 1024-token window. + hybrid = _make_model( + 27_000_000_000, + sliding_window=1024, + sliding_window_global_ratio=1.0 / 6.0, + ) + dense = _make_model(27_000_000_000) + short = estimate_kv_cache(hybrid, 4096) + long_hybrid = estimate_kv_cache(hybrid, 131072) + long_dense = estimate_kv_cache(dense, 131072) + # Still grows with context (global layers keep scaling)... + assert long_hybrid > short + # ...but well under the dense estimate (roughly the global ratio). + assert long_hybrid < long_dense + assert long_hybrid < long_dense * 0.30 + + +def test_kv_cache_never_exceeds_dense_estimate(): + """The SWA reduction can only ever lower the estimate, never raise it.""" + for ratio in (0.0, 1.0 / 6.0, 0.25, 0.5, 1.0): + swa = _make_model( + 13_000_000_000, + sliding_window=2048, + sliding_window_global_ratio=ratio, + ) + dense = _make_model(13_000_000_000) + for ctx in (1024, 4096, 65536): + assert estimate_kv_cache(swa, ctx) <= estimate_kv_cache(dense, ctx) + + +def test_kv_cache_below_window_matches_dense(): + """When context fits inside the window, SWA and dense agree.""" + swa = _make_model( + 7_000_000_000, sliding_window=8192, sliding_window_global_ratio=0.0 + ) + dense = _make_model(7_000_000_000) + assert estimate_kv_cache(swa, 4096) == estimate_kv_cache(dense, 4096)