Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 39 additions & 9 deletions astrbot/builtin_stars/astrbot/group_chat_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
from astrbot.api.provider import Provider, ProviderRequest
from astrbot.core.agent.message import TextPart
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.utils.image_caption_cache import (
image_caption_cache,
resolve_image_caption_cache_ttl,
)

"""
Group chat context awareness.
Expand Down Expand Up @@ -67,6 +71,9 @@ def cfg(self, event: AstrMessageEvent):
"image_caption": image_caption,
"image_caption_prompt": image_caption_prompt,
"image_caption_provider_id": image_caption_provider_id,
"image_caption_cache_ttl": resolve_image_caption_cache_ttl(
cfg.get("provider_settings", {})
),
"enable_active_reply": enable_active_reply,
"ar_method": ar_method,
"ar_possibility": ar_possibility,
Expand All @@ -79,22 +86,44 @@ async def get_image_caption(
image_url: str,
image_caption_provider_id: str,
image_caption_prompt: str,
cache_ttl: int = 0,
) -> str:
if not image_caption_provider_id:
provider = self.context.get_using_provider()
provider_id = (
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
Outdated
provider.provider_config.get("id", "")
if isinstance(provider, Provider)
else ""
)
else:
provider = self.context.get_provider_by_id(image_caption_provider_id)
provider_id = image_caption_provider_id
if not provider:
raise Exception(f"没有找到 ID 为 {image_caption_provider_id} 的提供商")
raise Exception(
f"Provider `{image_caption_provider_id}` was not found."
)

if not isinstance(provider, Provider):
raise Exception(f"提供商类型错误({type(provider)}),无法获取图片描述")
response = await provider.text_chat(
raise Exception(
f"Provider type is invalid for image captioning: {type(provider)}."
)

async def _caption_factory() -> str:
response = await provider.text_chat(
prompt=image_caption_prompt,
session_id=uuid.uuid4().hex,
image_urls=[image_url],
persist=False,
)
return response.completion_text

return await image_caption_cache.get_or_create(
provider_id=provider_id,
prompt=image_caption_prompt,
session_id=uuid.uuid4().hex,
image_urls=[image_url],
persist=False,
ttl_seconds=cache_ttl,
caption_factory=_caption_factory,
)
return response.completion_text

async def need_active_reply(self, event: AstrMessageEvent) -> bool:
cfg = self.cfg(event)
Expand Down Expand Up @@ -195,15 +224,16 @@ async def _format_message(self, event: AstrMessageEvent, cfg: dict) -> str:
try:
url = comp.url if comp.url else comp.file
if not url:
raise Exception("图片 URL 为空")
raise Exception("Image URL is empty.")
caption = await self.get_image_caption(
url,
cfg["image_caption_provider_id"],
cfg["image_caption_prompt"],
cfg["image_caption_cache_ttl"],
)
parts.append(f" [Image: {caption}]")
except Exception as e:
logger.error(f"获取图片描述失败: {e}")
logger.error(f"Failed to get image caption: {e}")
else:
parts.append(" [Image]")
elif isinstance(comp, At):
Expand All @@ -212,7 +242,7 @@ async def _format_message(self, event: AstrMessageEvent, cfg: dict) -> str:
"all",
)
if is_at_self:
parts.insert(1, "⚠️[DIRECTED AT YOU] ")
parts.insert(1, "[DIRECTED AT YOU] ")
parts.append(f" [At: {comp.name}]")

return "".join(parts)
Expand Down
46 changes: 39 additions & 7 deletions astrbot/core/astr_main_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@
get_astrbot_workspaces_path,
)
from astrbot.core.utils.file_extract import extract_file_moonshotai
from astrbot.core.utils.image_caption_cache import (
image_caption_cache,
resolve_image_caption_cache_ttl,
)
from astrbot.core.utils.llm_metadata import LLM_METADATAS
from astrbot.core.utils.media_utils import (
IMAGE_COMPRESS_DEFAULT_MAX_SIZE,
Expand Down Expand Up @@ -583,6 +587,7 @@ async def _request_img_caption(
cfg: dict,
image_urls: list[str],
plugin_context: Context,
prompt: str | None = None,
) -> str:
prov = plugin_context.get_provider_by_id(provider_id)
if prov is None:
Expand All @@ -594,16 +599,27 @@ async def _request_img_caption(
f"Cannot get image caption because provider `{provider_id}` is not a valid Provider, it is {type(prov)}.",
)

img_cap_prompt = cfg.get(
img_cap_prompt = prompt or cfg.get(
"image_caption_prompt",
"Please describe the image.",
)
cache_ttl = resolve_image_caption_cache_ttl(cfg)
logger.debug("Processing image caption with provider: %s", provider_id)
llm_resp = await prov.text_chat(

async def _caption_factory() -> str:
llm_resp = await prov.text_chat(
prompt=img_cap_prompt,
image_urls=image_urls,
)
return llm_resp.completion_text

return await image_caption_cache.get_or_create(
provider_id=provider_id,
prompt=img_cap_prompt,
image_urls=image_urls,
ttl_seconds=cache_ttl,
caption_factory=_caption_factory,
)
return llm_resp.completion_text


async def _ensure_img_caption(
Expand Down Expand Up @@ -808,13 +824,21 @@ async def _process_quote_message(
)
if path and _is_generated_compressed_image_path(path, compress_path):
event.track_temporary_local_file(compress_path)
llm_resp = await prov.text_chat(
caption = await _request_img_caption(
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
Outdated
prov.provider_config.get("id", img_cap_prov_id or ""),
{
"image_caption_prompt": "Please describe the image content.",
"image_caption_cache_ttl": resolve_image_caption_cache_ttl(
config.provider_settings if config else None
),
},
[compress_path],
_QuotedImageCaptionContext(prov),
prompt="Please describe the image content.",
image_urls=[compress_path],
)
if llm_resp.completion_text:
if caption:
content_parts.append(
f"[Image Caption in quoted message]: {llm_resp.completion_text}"
f"[Image Caption in quoted message]: {caption}"
)
else:
logger.warning("No provider found for image captioning in quote.")
Expand All @@ -836,6 +860,14 @@ async def _process_quote_message(
req.extra_user_content_parts.append(TextPart(text=quoted_text))


class _QuotedImageCaptionContext:
def __init__(self, provider: Provider) -> None:
self._provider = provider

def get_provider_by_id(self, provider_id: str) -> Provider:
return self._provider
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
Outdated


def _append_system_reminders(
event: AstrMessageEvent,
req: ProviderRequest,
Expand Down
6 changes: 6 additions & 0 deletions astrbot/core/config/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@
"fallback_chat_models": [],
"default_image_caption_provider_id": "",
"image_caption_prompt": "Please describe the image using Chinese.",
"image_caption_cache_ttl": 600,
"provider_pool": ["*"], # "*" 表示使用所有可用的提供者
"wake_prefix": "",
"web_search": False,
Expand Down Expand Up @@ -3193,6 +3194,11 @@
"description": "图片转述提示词",
"type": "text",
},
"provider_settings.image_caption_cache_ttl": {
"description": "图片转述缓存时长(秒)",
"type": "int",
"hint": "在缓存时间内再次收到相同图片时,直接复用已缓存的视觉识别结果;设为 0 表示禁用缓存",
},
},
"condition": {
"provider_settings.enable": True,
Expand Down
169 changes: 169 additions & 0 deletions astrbot/core/utils/image_caption_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
from __future__ import annotations

import asyncio
import base64
import hashlib
import time
from dataclasses import dataclass
from pathlib import Path
from urllib.parse import unquote, urlparse

from astrbot.core import logger

DEFAULT_IMAGE_CAPTION_CACHE_TTL = 600


def resolve_image_caption_cache_ttl(config: dict | None) -> int:
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
if not isinstance(config, dict):
return DEFAULT_IMAGE_CAPTION_CACHE_TTL

ttl = config.get(
"image_caption_cache_ttl",
DEFAULT_IMAGE_CAPTION_CACHE_TTL,
)
if isinstance(ttl, bool):
return DEFAULT_IMAGE_CAPTION_CACHE_TTL
try:
return max(int(ttl), 0)
except (TypeError, ValueError):
return DEFAULT_IMAGE_CAPTION_CACHE_TTL


@dataclass(slots=True)
class _ImageCaptionCacheEntry:
caption: str
expires_at: float


class ImageCaptionCache:
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
def __init__(self) -> None:
self._entries: dict[str, _ImageCaptionCacheEntry] = {}
self._locks: dict[str, asyncio.Lock] = {}

def clear(self) -> None:
self._entries.clear()
self._locks.clear()

async def get_or_create(
self,
*,
provider_id: str,
prompt: str,
image_urls: list[str],
ttl_seconds: int,
caption_factory,
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
Outdated
) -> str:
if ttl_seconds <= 0:
return await caption_factory()

cache_key = await self._build_cache_key(
provider_id=provider_id,
prompt=prompt,
image_urls=image_urls,
)
cached_caption = self._get(cache_key)
if cached_caption is not None:
logger.debug(
"Using cached image caption. provider=%s",
provider_id or "<default>",
)
return cached_caption

lock = self._locks.setdefault(cache_key, asyncio.Lock())
async with lock:
cached_caption = self._get(cache_key)
if cached_caption is not None:
logger.debug(
"Using cached image caption after lock wait. provider=%s",
provider_id or "<default>",
)
return cached_caption

Comment thread
sourcery-ai[bot] marked this conversation as resolved.
Outdated
caption = await caption_factory()
self._entries[cache_key] = _ImageCaptionCacheEntry(
caption=caption,
expires_at=time.monotonic() + ttl_seconds,
)
self._cleanup_expired_entries()
return caption
Comment thread
FloranceYeh marked this conversation as resolved.

def _get(self, cache_key: str) -> str | None:
entry = self._entries.get(cache_key)
if entry is None:
return None
if entry.expires_at <= time.monotonic():
self._entries.pop(cache_key, None)
return None
return entry.caption

def _cleanup_expired_entries(self) -> None:
now = time.monotonic()
expired_keys = [
key for key, entry in self._entries.items() if entry.expires_at <= now
]
for key in expired_keys:
self._entries.pop(key, None)
Comment thread
sourcery-ai[bot] marked this conversation as resolved.

async def _build_cache_key(
self,
*,
provider_id: str,
prompt: str,
image_urls: list[str],
) -> str:
image_fingerprints = []
for image_url in image_urls:
image_fingerprints.append(await self._fingerprint_image(image_url))

joined = "\n".join([provider_id, prompt, *image_fingerprints])
return hashlib.sha256(joined.encode("utf-8")).hexdigest()

async def _fingerprint_image(self, image_url: str) -> str:
if image_url.startswith("base64://"):
raw_base64 = image_url.removeprefix("base64://")
try:
image_bytes = base64.b64decode(raw_base64)
except Exception:
return f"ref:{image_url}"
return self._hash_bytes(image_bytes)

if image_url.startswith("data:image"):
try:
_, encoded = image_url.split(",", 1)
image_bytes = base64.b64decode(encoded)
except Exception:
return f"ref:{image_url}"
return self._hash_bytes(image_bytes)

if image_url.startswith(("http://", "https://")):
return f"url:{image_url}"

local_path = self._to_local_path(image_url)
if local_path and local_path.is_file():
image_bytes = await asyncio.to_thread(local_path.read_bytes)
return self._hash_bytes(image_bytes)

return f"ref:{image_url}"

def _to_local_path(self, image_url: str) -> Path | None:
if image_url.startswith("file://"):
parsed = urlparse(image_url)
parsed_path = unquote(parsed.path)
if (
parsed_path.startswith("/")
and len(parsed_path) >= 3
and parsed_path[2] == ":"
):
parsed_path = parsed_path[1:]
return Path(parsed_path)

if image_url.startswith(("http://", "https://", "base64://", "data:image")):
return None

return Path(image_url)

def _hash_bytes(self, payload: bytes) -> str:
return hashlib.sha256(payload).hexdigest()


image_caption_cache = ImageCaptionCache()
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Auto-generated MDI subset – 271 icons */
/* Auto-generated MDI subset – 272 icons */
/* Do not edit manually. Run: pnpm run subset-icons */

@font-face {
Expand Down Expand Up @@ -464,6 +464,10 @@
content: "\F1036";
}

.mdi-file-search-outline::before {
content: "\F0C7D";
}

.mdi-file-upload::before {
content: "\F0A4D";
}
Expand Down
Binary file not shown.
Binary file not shown.
Loading