Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 67 additions & 6 deletions astrbot/builtin_stars/astrbot/group_chat_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
from astrbot.api.provider import Provider, ProviderRequest
from astrbot.core.agent.message import TextPart
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.utils.image_caption_cache import (
image_caption_cache,
resolve_image_caption_cache_ttl,
)

"""
Group chat context awareness.
Expand Down Expand Up @@ -67,6 +71,9 @@ def cfg(self, event: AstrMessageEvent):
"image_caption": image_caption,
"image_caption_prompt": image_caption_prompt,
"image_caption_provider_id": image_caption_provider_id,
"image_caption_cache_ttl": resolve_image_caption_cache_ttl(
cfg.get("provider_settings", {})
),
"enable_active_reply": enable_active_reply,
"ar_method": ar_method,
"ar_possibility": ar_possibility,
Expand All @@ -79,17 +86,46 @@ async def get_image_caption(
image_url: str,
image_caption_provider_id: str,
image_caption_prompt: str,
cache_ttl: int = 0,
) -> str:
if not image_caption_provider_id:
provider = self.context.get_using_provider()
else:
provider = self.context.get_provider_by_id(image_caption_provider_id)
if not provider:
raise Exception(f"没有找到 ID 为 {image_caption_provider_id} 的提供商")
raise Exception(
f"Provider `{image_caption_provider_id}` was not found."
)

if not isinstance(provider, Provider):
raise Exception(f"提供商类型错误({type(provider)}),无法获取图片描述")
response = await provider.text_chat(
raise Exception(
f"Provider type is invalid for image captioning: {type(provider)}."
)
provider_id = _resolve_provider_cache_identity(
provider,
configured_provider_id=image_caption_provider_id,
)

return await image_caption_cache.get_or_create(
provider_id=provider_id,
prompt=image_caption_prompt,
image_urls=[image_url],
ttl_seconds=cache_ttl,
caption_factory=lambda: self._fetch_image_caption(
provider,
image_caption_prompt,
image_url,
),
)

async def _fetch_image_caption(
self,
provider: Provider,
prompt: str,
image_url: str,
) -> str:
response = await provider.text_chat(
prompt=prompt,
session_id=uuid.uuid4().hex,
image_urls=[image_url],
persist=False,
Expand Down Expand Up @@ -195,15 +231,16 @@ async def _format_message(self, event: AstrMessageEvent, cfg: dict) -> str:
try:
url = comp.url if comp.url else comp.file
if not url:
raise Exception("图片 URL 为空")
raise Exception("Image URL is empty.")
caption = await self.get_image_caption(
url,
cfg["image_caption_provider_id"],
cfg["image_caption_prompt"],
cfg["image_caption_cache_ttl"],
)
parts.append(f" [Image: {caption}]")
except Exception as e:
logger.error(f"获取图片描述失败: {e}")
logger.error(f"Failed to get image caption: {e}")
else:
parts.append(" [Image]")
elif isinstance(comp, At):
Expand All @@ -212,7 +249,7 @@ async def _format_message(self, event: AstrMessageEvent, cfg: dict) -> str:
"all",
)
if is_at_self:
parts.insert(1, "⚠️[DIRECTED AT YOU] ")
parts.insert(1, "[DIRECTED AT YOU] ")
parts.append(f" [At: {comp.name}]")

return "".join(parts)
Expand All @@ -239,3 +276,27 @@ def _trim_left(

def _format_group_history_block(records: list[str]) -> str:
    """Wrap group-chat history records in the module's header and footer.

    Args:
        records: History record strings, in the order they should appear.

    Returns:
        ``GROUP_HISTORY_HEADER`` followed by the records joined with
        newlines, followed by ``GROUP_HISTORY_FOOTER``. With an empty list
        the header is immediately followed by the footer.
    """
    return GROUP_HISTORY_HEADER + "\n".join(records) + GROUP_HISTORY_FOOTER


def _resolve_provider_cache_identity(
    provider: Provider,
    configured_provider_id: str,
) -> str:
    """Pick the string that identifies ``provider`` in image-caption cache keys.

    The first available source wins:

    1. ``configured_provider_id``, when it is non-empty.
    2. The ``"id"`` entry of ``provider.provider_config``, when it is a
       non-empty string.
    3. A composite ``module:qualname:type:model`` string built from the
       provider's class, the ``"type"`` entry of its config and the value
       returned by ``provider.get_model()``.

    Args:
        provider: Provider instance whose identity is being resolved.
        configured_provider_id: Explicitly configured provider ID; may be
            empty when no specific provider was configured.

    Returns:
        The identity string selected by the rules above.
    """
    if configured_provider_id:
        return configured_provider_id

    # A missing/None config is treated as empty so the lookups below are safe.
    provider_config = provider.provider_config or {}
    provider_id = provider_config.get("id", "")
    if isinstance(provider_id, str) and provider_id:
        return provider_id

    # No usable ID: fall back to class identity plus type and model, with
    # None rendered as an empty component so the string keeps four fields.
    provider_type = provider_config.get("type", "")
    model = provider.get_model()
    return ":".join(
        [
            provider.__class__.__module__,
            provider.__class__.__qualname__,
            "" if provider_type is None else str(provider_type),
            "" if model is None else str(model),
        ]
    )
50 changes: 43 additions & 7 deletions astrbot/core/astr_main_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@
get_astrbot_workspaces_path,
)
from astrbot.core.utils.file_extract import extract_file_moonshotai
from astrbot.core.utils.image_caption_cache import (
image_caption_cache,
resolve_image_caption_cache_ttl,
)
from astrbot.core.utils.llm_metadata import LLM_METADATAS
from astrbot.core.utils.media_utils import (
IMAGE_COMPRESS_DEFAULT_MAX_SIZE,
Expand Down Expand Up @@ -583,6 +587,7 @@ async def _request_img_caption(
cfg: dict,
image_urls: list[str],
plugin_context: Context,
prompt: str | None = None,
) -> str:
prov = plugin_context.get_provider_by_id(provider_id)
if prov is None:
Expand All @@ -594,16 +599,27 @@ async def _request_img_caption(
f"Cannot get image caption because provider `{provider_id}` is not a valid Provider, it is {type(prov)}.",
)

img_cap_prompt = cfg.get(
img_cap_prompt = prompt or cfg.get(
"image_caption_prompt",
"Please describe the image.",
)
cache_ttl = resolve_image_caption_cache_ttl(cfg)
logger.debug("Processing image caption with provider: %s", provider_id)
llm_resp = await prov.text_chat(

async def _caption_factory() -> str:
llm_resp = await prov.text_chat(
prompt=img_cap_prompt,
image_urls=image_urls,
)
return llm_resp.completion_text

return await image_caption_cache.get_or_create(
provider_id=provider_id,
prompt=img_cap_prompt,
image_urls=image_urls,
ttl_seconds=cache_ttl,
caption_factory=_caption_factory,
)
return llm_resp.completion_text


async def _ensure_img_caption(
Expand Down Expand Up @@ -808,13 +824,21 @@ async def _process_quote_message(
)
if path and _is_generated_compressed_image_path(path, compress_path):
event.track_temporary_local_file(compress_path)
llm_resp = await prov.text_chat(
caption = await _request_img_caption(
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
Outdated
prov.provider_config.get("id", img_cap_prov_id or ""),
{
"image_caption_prompt": "Please describe the image content.",
"image_caption_cache_ttl": resolve_image_caption_cache_ttl(
config.provider_settings if config else None
),
},
[compress_path],
_QuotedImageCaptionContext(prov),
prompt="Please describe the image content.",
image_urls=[compress_path],
)
if llm_resp.completion_text:
if caption:
content_parts.append(
f"[Image Caption in quoted message]: {llm_resp.completion_text}"
f"[Image Caption in quoted message]: {caption}"
)
else:
logger.warning("No provider found for image captioning in quote.")
Expand All @@ -836,6 +860,18 @@ async def _process_quote_message(
req.extra_user_content_parts.append(TextPart(text=quoted_text))


class _QuotedImageCaptionContext:
    """Minimal context object that serves one already-resolved provider.

    It exposes only ``get_provider_by_id`` so it can be passed where a
    provider lookup by ID is expected while always answering with the
    wrapped provider.
    """

    def __init__(self, provider: Provider) -> None:
        """Store the provider that every lookup will return."""
        self._provider = provider

    def get_provider_by_id(self, provider_id: str) -> Provider:
        """Return the wrapped provider, rejecting a mismatching ID.

        Args:
            provider_id: ID the caller is asking for.

        Returns:
            The provider given to the constructor.

        Raises:
            ValueError: If the wrapped provider has a non-None ``id``
                attribute that differs from ``provider_id``.
        """
        # NOTE(review): the guard reads a plain ``id`` attribute, while the
        # call site derives the ID from ``provider_config["id"]``. When the
        # provider has no ``id`` attribute the check is skipped entirely;
        # confirm that this is intended.
        wrapped_id = getattr(self._provider, "id", None)
        if wrapped_id is not None and provider_id != wrapped_id:
            raise ValueError(
                f"Requested provider_id '{provider_id}' does not match "
                f"wrapped provider id '{wrapped_id}'."
            )
        return self._provider
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
Outdated

def _append_system_reminders(
event: AstrMessageEvent,
req: ProviderRequest,
Expand Down
6 changes: 6 additions & 0 deletions astrbot/core/config/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@
"fallback_chat_models": [],
"default_image_caption_provider_id": "",
"image_caption_prompt": "Please describe the image using Chinese.",
"image_caption_cache_ttl": 600,
"provider_pool": ["*"], # "*" 表示使用所有可用的提供者
"wake_prefix": "",
"web_search": False,
Expand Down Expand Up @@ -3193,6 +3194,11 @@
"description": "图片转述提示词",
"type": "text",
},
"provider_settings.image_caption_cache_ttl": {
"description": "图片转述缓存时长(秒)",
"type": "int",
"hint": "在缓存时间内再次收到相同图片时,直接复用已缓存的视觉识别结果;设为 0 表示禁用缓存",
},
},
"condition": {
"provider_settings.enable": True,
Expand Down
Loading