Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions backend/tests/unit/test_prompt_cache_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,6 +691,112 @@ def __init__(self, **kwargs):
assert "prompt_cache_key" not in mkw, f"Client {call.get('model')} should not have prompt_cache_key"


def _load_clients_namespace(captured_calls, byok_key=None):
"""Exec clients.py in an isolated namespace with fake providers, returning the namespace.

captured_calls collects every ChatOpenAI(**kwargs) construction so tests can assert the
runtime kwargs (extra_body / prompt_cache_retention) per model.
"""

class FakeChatOpenAI:
def __init__(self, **kwargs):
self.kwargs = kwargs
captured_calls.append(kwargs)

def bind(self, **kwargs):
self.bound = kwargs
return self

class FakeOpenAIEmbeddings:
def __init__(self, **kwargs):
pass

fake_tiktoken = _stub_module("tiktoken_fake2")
fake_tiktoken.encoding_for_model = MagicMock(return_value=MagicMock())
fake_anthropic = _stub_module("anthropic_fake2")
fake_anthropic.AsyncAnthropic = MagicMock

source = (BACKEND_DIR / "utils" / "llm" / "clients.py").read_text(encoding="utf-8")
for line in [
"from langchain_core.language_models import BaseChatModel",
"from langchain_openai import ChatOpenAI, OpenAIEmbeddings",
"from langchain_google_genai import ChatGoogleGenerativeAI",
"import tiktoken",
"import anthropic",
"from langchain_core.output_parsers import PydanticOutputParser",
"from models.structured import Structured",
"from utils.byok import get_byok_key",
"from utils.llm.usage_tracker import get_usage_callback",
]:
source = source.replace(line, "")

ns = {
"os": os,
"BaseChatModel": object,
"ChatOpenAI": FakeChatOpenAI,
"ChatGoogleGenerativeAI": FakeChatOpenAI,
"OpenAIEmbeddings": FakeOpenAIEmbeddings,
"tiktoken": fake_tiktoken,
"anthropic": fake_anthropic,
"PydanticOutputParser": MagicMock(),
"Structured": MagicMock(),
"get_byok_key": MagicMock(return_value=byok_key),
"get_usage_callback": MagicMock(return_value=[]),
"List": list,
}
exec(source, ns)
return ns


def test_renamed_gpt5_model_still_gets_cache_features():
"""
Capability-based gating (not exact model names): a different gpt-5 family model that is
not the hardcoded 'gpt-5.1'/'gpt-5.4' must still receive prompt_cache_retention and be
eligible for prompt_cache_key routing.
"""
captured_calls = []
ns = _load_clients_namespace(captured_calls)

# A hypothetical future/renamed gpt-5 family model.
new_model = "gpt-5.9-turbo"

# Retention must be applied for the new gpt-5 model via the default factory.
ns["_get_or_create_openai_llm"](new_model)
new_calls = [c for c in captured_calls if c.get("model") == new_model]
assert new_calls, "Factory should have constructed the new gpt-5 model client"
for call in new_calls:
assert (
call.get("extra_body", {}).get("prompt_cache_retention") == "24h"
), f"Renamed gpt-5 model should get 24h retention: {call}"

# prompt_cache_key routing must be eligible for the new model.
assert ns["_supports_prompt_cache_key"](new_model), "Renamed gpt-5 model should support prompt_cache_key"
assert ns["_supports_cache_retention"](new_model), "Renamed gpt-5 model should support cache retention"


def test_non_cache_capable_model_is_unchanged():
"""A non-cache-capable model (e.g. a Gemini name) must not get cache retention/routing."""
captured_calls = []
ns = _load_clients_namespace(captured_calls)

assert not ns["_supports_prompt_cache_key"]("gemini-2.5-flash-lite")
assert not ns["_supports_cache_retention"]("gemini-2.5-flash-lite")
# gpt-4.1-mini supports routing but not 24h retention.
assert ns["_supports_prompt_cache_key"]("gpt-4.1-mini")
assert not ns["_supports_cache_retention"]("gpt-4.1-mini")


def test_get_llm_binds_cache_key_for_cache_capable_models():
"""get_llm must bind prompt_cache_key for cache-capable models beyond the old hardcoded set."""
captured_calls = []
ns = _load_clients_namespace(captured_calls)

# Force the active profile to route chat_responses to a gpt-5 model so get_llm resolves it.
ns["_active_profile"]["chat_responses"] = ("gpt-5.4", "openai")
llm = ns["get_llm"]("chat_responses", cache_key="omi-chat")
assert getattr(llm, "bound", {}).get("prompt_cache_key") == "omi-chat", "get_llm should bind prompt_cache_key"


# ---------------------------------------------------------------------------
# Tests: Tool list construction in execute functions
# ---------------------------------------------------------------------------
Expand Down
4 changes: 3 additions & 1 deletion backend/tests/unit/test_prompt_cache_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,9 @@ def test_qos_cache_key_in_clients():
"""Omi QoS get_llm() should support cache_key parameter for prompt cache routing."""
source = _read_clients_source()
assert "cache_key" in source, "clients.py get_llm() should accept cache_key parameter"
assert "_CACHE_KEY_MODELS" in source, "clients.py should define _CACHE_KEY_MODELS for model-safe cache key handling"
assert (
"_supports_prompt_cache_key" in source
), "clients.py should gate prompt_cache_key by capability (_supports_prompt_cache_key)"


def test_qos_medium_tier_uses_extra_body_for_cache_retention():
Expand Down
29 changes: 20 additions & 9 deletions backend/tests/unit/test_prompt_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,21 +330,32 @@ def _read_clients_source():
clients_path = Path(__file__).resolve().parent.parent.parent / "utils" / "llm" / "clients.py"
return clients_path.read_text(encoding="utf-8")

def test_qos_gpt51_has_cache_retention(self):
"""QoS _get_or_create_openai_llm must set prompt_cache_retention=24h for gpt-5.1."""
def test_qos_openai_llm_gates_retention_by_capability(self):
"""_get_or_create_openai_llm must set prompt_cache_retention=24h via a capability check."""
source = self._read_clients_source()
match = re.search(
r"_get_or_create_openai_llm.*?gpt-5\.1.*?prompt_cache_retention.*?24h",
r"_get_or_create_openai_llm.*?_supports_cache_retention\(.*?prompt_cache_retention.*?24h",
source,
re.DOTALL,
)
assert match, "_get_or_create_openai_llm should set prompt_cache_retention='24h' for gpt-5.1"
assert (
match
), "_get_or_create_openai_llm should gate prompt_cache_retention='24h' by _supports_cache_retention()"

def test_qos_tier_medium_gets_cache_retention(self):
"""Omi QoS tier medium (gpt-5.1) must set prompt_cache_retention=24h via _get_or_create_openai_llm."""
source = self._read_clients_source()
match = re.search(r'_get_or_create_openai_llm.*?gpt-5\.1.*?prompt_cache_retention.*?24h', source, re.DOTALL)
assert match, "QoS _get_or_create_openai_llm should set prompt_cache_retention='24h' for gpt-5.1"
def test_gpt51_is_cache_retention_capable(self):
"""gpt-5.1 must still be recognized as 24h-retention capable after the capability refactor."""
from pathlib import Path

clients_path = Path(__file__).resolve().parent.parent.parent / "utils" / "llm" / "clients.py"
# Extract the prefix tuple and the predicate directly from source to avoid importing
# clients.py (which pulls heavy provider SDKs not available in the unit-test env).
source = clients_path.read_text(encoding="utf-8")
m = re.search(r"_CACHE_RETENTION_MODEL_PREFIXES\s*=\s*\(([^)]*)\)", source)
assert m, "clients.py should define _CACHE_RETENTION_MODEL_PREFIXES"
prefixes = tuple(p.strip().strip("'\"") for p in m.group(1).split(",") if p.strip())
assert "gpt-5.1".startswith(prefixes), f"gpt-5.1 should match a retention-capable prefix in {prefixes}"
# A renamed gpt-5 family model should also be covered (the whole point of the refactor).
assert "gpt-5.4".startswith(prefixes), "gpt-5.4 should be retention-capable"

def test_cache_retention_not_in_model_kwargs(self):
"""prompt_cache_retention must NOT be in model_kwargs (SDK rejects it there)."""
Expand Down
36 changes: 31 additions & 5 deletions backend/utils/llm/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def _create_byok_client(
) -> Optional[ChatOpenAI]:
"""Create a ChatOpenAI using the user's BYOK key. Returns None if BYOK not supported for this provider."""
kwargs: Dict[str, Any] = {'callbacks': [_usage_callback], 'request_timeout': 120, 'max_retries': 1}
if model == 'gpt-5.1':
if _supports_cache_retention(model):
kwargs['extra_body'] = {"prompt_cache_retention": "24h"}
if streaming:
kwargs['streaming'] = True
Expand Down Expand Up @@ -397,8 +397,34 @@ def get_openai_chat(model: str, **kwargs) -> ChatOpenAI:
'wrapped_analysis': 0.7,
}

# Models that support OpenAI prompt caching (prompt_cache_key routing).
_CACHE_KEY_MODELS = {'gpt-5.4', 'gpt-5.4-mini'}
# Prompt-cache capability detection.
#
# OpenAI prompt caching is a capability of whole model families, not specific point
# releases. Gating on exact names (e.g. {'gpt-5.4', 'gpt-5.4-mini'}) silently breaks the
# moment a model is renamed or a new family member ships, so we detect by family prefix
# instead.
#
# prompt_cache_key — request routing for the prefix cache. Supported by the gpt-4o,
# gpt-4.1, gpt-5.x and o-series families.
# prompt_cache_retention='24h' — extended (24h) cache retention. Supported by the
# gpt-5.x and o-series families.

# Family prefixes whose models support OpenAI prompt caching (prompt_cache_key routing).
_CACHE_KEY_MODEL_PREFIXES = ('gpt-5', 'gpt-4.1', 'gpt-4o', 'o1', 'o3', 'o4')

# Family prefixes whose models support 24h prompt-cache retention.
_CACHE_RETENTION_MODEL_PREFIXES = ('gpt-5', 'o1', 'o3', 'o4')
Comment on lines +413 to +416

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 O-series prefixes remain individually enumerated

The o-series entries (o1, o3, o4) are still listed one-by-one — the same pattern this PR correctly fixes for the gpt-5 family. OpenAI skipped o2 entirely and has been shipping new o-series models (o1 → o3 → o4) at a steady pace; a future o5 or o6 model would silently receive neither prompt_cache_key routing nor prompt_cache_retention until someone manually adds the prefix. Consolidating to a single 'o' prefix may be too broad (non-OpenAI models), but using a narrow shared root like 'o1-'/'o3-' won't help either. A pragmatic middle ground would be to add a comment flagging this and pairing each new o-series release with a prefix update, or to derive the o-series check from a small set such as ('gpt-5', 'gpt-4.1', 'gpt-4o', 'o') with an additional digit guard.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!



def _supports_prompt_cache_key(model: str) -> bool:
"""Whether a model supports OpenAI prompt-cache routing (prompt_cache_key)."""
return bool(model) and model.startswith(_CACHE_KEY_MODEL_PREFIXES)


def _supports_cache_retention(model: str) -> bool:
"""Whether a model supports 24h OpenAI prompt-cache retention."""
return bool(model) and model.startswith(_CACHE_RETENTION_MODEL_PREFIXES)


# Features that call .with_structured_output() — logged when resolving to Gemini for compat monitoring.
_STRUCTURED_OUTPUT_FEATURES = {
Expand Down Expand Up @@ -463,7 +489,7 @@ def _get_or_create_openai_llm(model_name: str, streaming: bool = False) -> ChatO
'request_timeout': 120,
'max_retries': 1,
}
if model_name == 'gpt-5.1':
if _supports_cache_retention(model_name):
kwargs['extra_body'] = {"prompt_cache_retention": "24h"}
if streaming:
kwargs['streaming'] = True
Expand Down Expand Up @@ -642,7 +668,7 @@ def get_llm(feature: str, streaming: bool = False, cache_key: Optional[str] = No
else:
result = _get_default_client(model, provider, streaming, feature)

if cache_key and model in _CACHE_KEY_MODELS:
if cache_key and _supports_prompt_cache_key(model):
return result.bind(prompt_cache_key=cache_key)
return result

Expand Down