Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion src/whichllm/engine/vram.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,40 @@
_MOE_ATTENTION_PARAM_MULTIPLIER = 4.0


def _effective_kv_context(model: ModelInfo, context_length: int) -> float:
"""Context length that actually contributes to the KV cache.

For sliding-window-attention (SWA) models, local-attention layers only keep
the last ``sliding_window`` tokens, so their KV footprint plateaus once the
request exceeds the window. Hybrid models interleave a fraction of global
(full-context) layers; that fraction is ``sliding_window_global_ratio``.

Effective context blends the two layer types::

global_ratio * ctx + (1 - global_ratio) * min(ctx, window)

This is only applied for architectures whose mainline runtimes honor SWA
(the fetcher leaves ``sliding_window`` ``None`` otherwise), and it can only
ever lower the estimate — so a model that does not advertise an honored
window keeps the full-context KV figure and stays conservative.
"""
window = model.sliding_window
ratio = model.sliding_window_global_ratio
if not window or window <= 0 or ratio is None:
return float(context_length)
ratio = min(max(ratio, 0.0), 1.0)
windowed = min(context_length, window)
return ratio * context_length + (1.0 - ratio) * windowed


def estimate_kv_cache(model: ModelInfo, context_length: int) -> int:
"""Estimate KV cache size in bytes for a given context length.

Dense models: KV ≈ 3 MB × params_b × ctx_k (FP16 K+V across all layers).
MoE models: scale from active params × an empirical multiplier because
attention shares across experts.
Sliding-window models cap local-layer KV at the window size (see
:func:`_effective_kv_context`).
"""
if model.is_moe and model.parameter_count_active:
# Active-params × MoE multiplier gives a reasonable proxy for the
Expand All @@ -34,7 +62,8 @@ def estimate_kv_cache(model: ModelInfo, context_length: int) -> int:
else:
params_b = model.parameter_count / 1e9

ctx_k = context_length / 1024
effective_ctx = _effective_kv_context(model, context_length)
ctx_k = effective_ctx / 1024
kv_bytes = int(params_b * ctx_k * _KV_BYTES_PER_BPARAM_PER_KCTX)
return max(kv_bytes, 0)

Expand Down
153 changes: 153 additions & 0 deletions src/whichllm/models/fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,148 @@ def _resolve_moe_active_params(
return None


# Sliding-window-attention (SWA) registry. We only model SWA KV-cache savings
# for architectures whose mainline runtimes actually honor interleaved SWA
# (llama.cpp's ISWA path, MLX). Each entry is (default_window_tokens,
# global_layer_ratio) where the ratio is the fraction of layers that use full
# (global) attention. Models outside this allowlist keep full-context KV so
# estimates stay conservative — notably Mistral-7B-v0.1, whose config declares
# sliding_window=4096 but whose window is ignored by mainline runtimes.
# - gemma2: alternating local/global -> 1/2 global
# - gemma3: 5 local : 1 global (sliding_window_pattern=6) -> 1/6 global
# - gpt_oss: alternating sliding/full -> 1/2 global
# - cohere2: 3 local : 1 global (sliding_window_pattern=4) -> 1/4 global
_SWA_ARCH_DEFAULTS: dict[str, tuple[int, float]] = {
"gemma2": (4096, 0.5),
"gemma3": (1024, 1.0 / 6.0),
"gpt_oss": (128, 0.5),
"cohere2": (4096, 0.25),
}

# Map the many spellings an arch string can take (HF model_type, the
# ForCausalLM/ForConditionalGeneration class prefix, and GGUF metadata) onto a
# canonical key in _SWA_ARCH_DEFAULTS.
_SWA_ARCH_ALIASES: dict[str, str] = {
"gemma2": "gemma2",
"gemma2_text": "gemma2",
"gemma3": "gemma3",
"gemma3_text": "gemma3",
"gpt_oss": "gpt_oss",
"gptoss": "gpt_oss",
"cohere2": "cohere2",
}

# Last-resort fallback keyed by a model-name token, for GGUF-only repos that
# expose neither an HF config nor GGUF architecture metadata. Matched only as a
# delimited token in the model name, and only when no *other* base-architecture
# family is named (to avoid mislabeling merges/finetunes). Kept deliberately
# small and precise — a false positive here would under-count VRAM.
_SWA_ID_HINTS: tuple[tuple[str, str], ...] = (
("gemma-3", "gemma3"),
("gemma-2", "gemma2"),
("gpt-oss", "gpt_oss"),
("command-r7b", "cohere2"),
)

# Other base families that, if named in a model id, make an id-hint match
# ambiguous (e.g. a "gemma-3 distilled from llama" merge) — bail conservatively.
_CONFLICTING_ARCH_TOKENS: tuple[str, ...] = (
"llama",
"qwen",
"mistral",
"mixtral",
"phi",
"deepseek",
"yi",
"falcon",
"internlm",
"baichuan",
"glm",
)


def _swa_key_from_arch(arch: str | None) -> str | None:
"""Resolve an arch string (model_type / class / gguf metadata) to a key."""
if not arch:
return None
arch = arch.lower()
if arch in _SWA_ARCH_ALIASES:
return _SWA_ARCH_ALIASES[arch]
stripped = arch.replace("forcausallm", "").replace("forconditionalgeneration", "")
if stripped in _SWA_ARCH_ALIASES:
return _SWA_ARCH_ALIASES[stripped]
return None


def _swa_key_from_id(model_id: str) -> str | None:
"""Boundary-aware, conflict-guarded model-id fallback (see _SWA_ID_HINTS)."""
name = model_id.split("/")[-1].lower()
for hint, key in _SWA_ID_HINTS:
if re.search(rf"(?<![a-z0-9]){re.escape(hint)}(?![a-z0-9])", name):
family = key.split("_")[0].rstrip("0123456789") # gemma3 -> gemma
if any(tok in name for tok in _CONFLICTING_ARCH_TOKENS if tok != family):
return None
return key
return None


def _swa_arch_key(config: dict, model_id: str, gguf_arch: str | None) -> str | None:
"""Identify the SWA architecture key for a model, or None if not honored.

Prefers authoritative sources — the raw HF config ``model_type``/
``architectures`` (not the normalized architecture string, which collapses
gemma2 and gemma3 into "gemma") and the GGUF metadata architecture — before
a narrow, conflict-guarded model-id fallback for config-less GGUF repos.
"""
model_type = config.get("model_type")
key = _swa_key_from_arch(model_type if isinstance(model_type, str) else None)
if key:
return key

arch_list = config.get("architectures") or []
if arch_list and isinstance(arch_list[0], str):
key = _swa_key_from_arch(arch_list[0])
if key:
return key

key = _swa_key_from_arch(gguf_arch)
if key:
return key

return _swa_key_from_id(model_id)


def _resolve_sliding_window(
config: dict, model_id: str, gguf_arch: str | None = None
) -> tuple[int | None, float | None]:
"""Resolve (sliding_window, global_ratio) for honored SWA architectures.

Returns (None, None) for every model outside the allowlist so the KV
estimate stays at full context (conservative).
"""
# Respect an explicit opt-out before doing any work.
if config.get("use_sliding_window") is False:
return None, None

key = _swa_arch_key(config, model_id, gguf_arch)
if key is None:
return None, None

default_window, default_ratio = _SWA_ARCH_DEFAULTS[key]

window = config.get("sliding_window")
if not isinstance(window, int) or window <= 0:
window = default_window

pattern = config.get("sliding_window_pattern")
if isinstance(pattern, int) and pattern > 0:
global_ratio = 1.0 / pattern
else:
global_ratio = default_ratio

return window, global_ratio


def _normalize_param_count(
extracted: int,
model_id: str,
Expand Down Expand Up @@ -549,6 +691,11 @@ def _parse_model(data: dict) -> ModelInfo | None:
if not context_length and isinstance(gguf_meta, dict):
context_length = gguf_meta.get("context_length")

gguf_arch = gguf_meta.get("architecture") if isinstance(gguf_meta, dict) else None
sliding_window, swa_global_ratio = _resolve_sliding_window(
config, model_id, gguf_arch
)

benchmark_scores: dict[str, float] = {}
eval_score = _extract_hf_eval_score(data)
if eval_score is not None:
Expand All @@ -570,6 +717,8 @@ def _parse_model(data: dict) -> ModelInfo | None:
gguf_variants=gguf_variants,
benchmark_scores=benchmark_scores,
base_model=base_model,
sliding_window=sliding_window,
sliding_window_global_ratio=swa_global_ratio,
)


Expand Down Expand Up @@ -877,6 +1026,8 @@ def models_to_dicts(models: list[ModelInfo]) -> list[dict]:
],
"benchmark_scores": m.benchmark_scores,
"base_model": m.base_model,
"sliding_window": m.sliding_window,
"sliding_window_global_ratio": m.sliding_window_global_ratio,
}
)
return result
Expand Down Expand Up @@ -925,6 +1076,8 @@ def dicts_to_models(data: list[dict]) -> list[ModelInfo]:
],
benchmark_scores=d.get("benchmark_scores", {}),
base_model=base_model,
sliding_window=d.get("sliding_window"),
sliding_window_global_ratio=d.get("sliding_window_global_ratio"),
)
)
return models
Expand Down
8 changes: 8 additions & 0 deletions src/whichllm/models/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ class ModelInfo:
gguf_variants: list[GGUFVariant] = field(default_factory=list)
benchmark_scores: dict[str, float] = field(default_factory=dict)
base_model: str | None = None # cardData.base_model
# Sliding-window-attention KV-cache modeling. Only populated for
# architectures whose mainline runtimes actually honor interleaved SWA
# (Gemma-2/3, gpt-oss, Cohere2); left None otherwise so VRAM estimates
# stay conservative. sliding_window is the local-attention window in
# tokens; sliding_window_global_ratio is the fraction of layers that use
# full (global) attention (0.0 = pure SWA, 1.0 = fully dense).
sliding_window: int | None = None
sliding_window_global_ratio: float | None = None


@dataclass
Expand Down
Loading