Skip to content
73 changes: 53 additions & 20 deletions astrbot/core/agent/runners/tool_loop_agent_runner.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

import asyncio
import copy
import sys
Expand Down Expand Up @@ -216,7 +217,7 @@ async def reset(
enforce_max_turns: int = -1,
# llm compressor
llm_compress_instruction: str | None = None,
llm_compress_keep_recent_ratio: float = 0.15,
llm_compress_keep_recent: int = 0,
Comment thread
bytecategory marked this conversation as resolved.
llm_compress_provider: Provider | None = None,
# truncate by turns compressor
truncate_turns: int = 1,
Expand All @@ -233,7 +234,7 @@ async def reset(
self.streaming = streaming
self.enforce_max_turns = enforce_max_turns
self.llm_compress_instruction = llm_compress_instruction
self.llm_compress_keep_recent_ratio = llm_compress_keep_recent_ratio
self.llm_compress_keep_recent = llm_compress_keep_recent
self.llm_compress_provider = llm_compress_provider
self.truncate_turns = truncate_turns
self.custom_token_counter = custom_token_counter
Expand All @@ -248,7 +249,7 @@ async def reset(
enforce_max_turns=self.enforce_max_turns,
truncate_turns=self.truncate_turns,
llm_compress_instruction=self.llm_compress_instruction,
llm_compress_keep_recent_ratio=self.llm_compress_keep_recent_ratio,
llm_compress_keep_recent=self.llm_compress_keep_recent,
llm_compress_provider=self.llm_compress_provider,
custom_token_counter=self.custom_token_counter,
custom_compressor=self.custom_compressor,
Expand Down Expand Up @@ -458,8 +459,11 @@ async def _iter_llm_responses(
self, *, include_model: bool = True
) -> T.AsyncGenerator[LLMResponse, None]:
"""Yields chunks *and* a final LLMResponse."""
messages_for_provider = getattr(
self, "_provider_messages", self.run_context.messages
)
payload = {
"contexts": self._sanitize_contexts_for_provider(self.run_context.messages),
"contexts": self._sanitize_contexts_for_provider(messages_for_provider),
"func_tool": self._func_tool_for_provider(),
"session_id": self.req.session_id,
"extra_user_content_parts": self.req.extra_user_content_parts, # list[ContentPart]
Expand Down Expand Up @@ -580,10 +584,7 @@ def _sanitize_contexts_for_provider(
self,
contexts: list[Message] | list[dict[str, T.Any]],
) -> list[Message] | list[dict[str, T.Any]]:
modalities = self.provider.provider_config.get("modalities", None)
if (
not modalities
): # Unconfigured (None or empty list) defaults to support all modalities
if not self._should_fix_modalities_for_provider():
return contexts
sanitized_contexts, stats = sanitize_contexts_by_modalities(
contexts,
Expand All @@ -592,6 +593,12 @@ def _sanitize_contexts_for_provider(
log_context_sanitize_stats(stats)
return sanitized_contexts

def _should_fix_modalities_for_provider(self) -> bool:
modalities = self.provider.provider_config.get("modalities", None)
return (
Comment thread
bytecategory marked this conversation as resolved.
isinstance(modalities, list) and modalities
) # Empty list is treated as unconfigured
Comment thread
bytecategory marked this conversation as resolved.

def _func_tool_for_provider(self) -> ToolSet | None:
if not self.req.func_tool:
return None
Expand All @@ -604,14 +611,11 @@ def _func_tool_for_provider(self) -> ToolSet | None:
return None
return self.req.func_tool

def _simple_print_message_role(self, tag: str, messages: list):
roles = [m.role for m in messages]
n = len(roles)
if n > 10:
summary = ",".join(roles[:4]) + ",...," + ",".join(roles[-4:])
else:
summary = ",".join(roles)
logger.debug(f"{tag} messages -> [{n}] {summary}")
def _simple_print_message_role(self, tag: str = ""):
roles = []
for message in self.run_context.messages:
roles.append(message.role)
logger.debug(f"{tag} RunCtx.messages -> [{len(roles)}] {','.join(roles)}")

def follow_up(
self,
Expand Down Expand Up @@ -705,13 +709,16 @@ async def step(self):
self._transition_state(AgentState.RUNNING)
llm_resp_result = None

# Process request-time context before sending it to the provider.
# Process request-time context on a copy so the runner's canonical
# messages are never mutated. The processed result is only used for this
# provider call. Persistent compaction is owned by the conversation /
# memory layer.
token_usage = self.req.conversation.token_usage if self.req.conversation else 0
self._simple_print_message_role("[BefCompact]", self.run_context.messages)
self.run_context.messages = await self.request_context_manager.process(
self._simple_print_message_role("[BefCompact]")
self._provider_messages = await self.request_context_manager.process(
self.run_context.messages, trusted_token_usage=token_usage
)
self._simple_print_message_role("[AftCompact]", self.run_context.messages)
self._simple_print_message_role("[AftCompact]")

async for llm_response in self._iter_llm_responses_with_fallback():
if llm_response.is_chunk:
Expand Down Expand Up @@ -893,6 +900,25 @@ async def step(self):
parts.append(TextPart(text=llm_resp.completion_text))
if len(parts) == 0:
parts = None

# 过滤掉无效的 tool calls,确保 assistant 消息不包含无 id/name 的条目
if llm_resp.tools_call_name and llm_resp.tools_call_ids:
valid_indices = [
i for i, (name, tid) in enumerate(
zip(llm_resp.tools_call_name, llm_resp.tools_call_ids)
)
if name and tid
]
if len(valid_indices) < len(llm_resp.tools_call_name):
llm_resp.tools_call_name = [llm_resp.tools_call_name[i] for i in valid_indices]
llm_resp.tools_call_args = [llm_resp.tools_call_args[i] for i in valid_indices]
Comment thread
bytecategory marked this conversation as resolved.
llm_resp.tools_call_ids = [llm_resp.tools_call_ids[i] for i in valid_indices]
Comment thread
bytecategory marked this conversation as resolved.

# 如果过滤后没有有效的 tool calls,跳过构建 tool_calls_result
if not llm_resp.tools_call_name:
await self._complete_with_assistant_response(llm_resp)
return

tool_calls_result = ToolCallsResult(
tool_calls_info=AssistantMessageSegment(
tool_calls=llm_resp.to_openai_to_calls_model(),
Expand Down Expand Up @@ -997,6 +1023,13 @@ def _append_tool_call_result(tool_call_id: str, content: str) -> None:
llm_response.tools_call_args,
llm_response.tools_call_ids,
):
# 跳过无效的 tool call(name 或 id 为空)
if not func_tool_name or not func_tool_id:
logger.warning(
f"Skipping invalid tool call with name={func_tool_name!r}, id={func_tool_id!r}"
)
continue

tool_result_blocks_start = len(tool_call_result_blocks)
tool_call_streak = self._track_tool_call_streak(func_tool_name)
yield _HandleFunctionToolsResult.from_message_chain(
Expand Down
14 changes: 14 additions & 0 deletions astrbot/core/provider/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,13 @@ def to_openai_tool_calls(self) -> list[dict]:
"""Convert to OpenAI tool calls format. Deprecated, use to_openai_to_calls_model instead."""
ret = []
for idx, tool_call_arg in enumerate(self.tools_call_args):
if not self.tools_call_name[idx]:
logger.warning(
f"Skipping tool call at index {idx} because function.name is empty/None. "
f"tool_call_id={self.tools_call_ids[idx] if idx < len(self.tools_call_ids) else 'N/A'}, "
f"arguments={tool_call_arg}"
)
continue
Comment thread
bytecategory marked this conversation as resolved.
payload = {
"id": self.tools_call_ids[idx],
"function": {
Expand All @@ -471,6 +478,13 @@ def to_openai_to_calls_model(self) -> list[ToolCall]:
"""The same as to_openai_tool_calls but return pydantic model."""
ret = []
for idx, tool_call_arg in enumerate(self.tools_call_args):
if not self.tools_call_name[idx]:
logger.warning(
f"Skipping tool call at index {idx} because function.name is empty/None. "
f"tool_call_id={self.tools_call_ids[idx] if idx < len(self.tools_call_ids) else 'N/A'}, "
f"arguments={tool_call_arg}"
)
continue
Comment thread
bytecategory marked this conversation as resolved.
Outdated
ret.append(
ToolCall(
id=self.tools_call_ids[idx],
Expand Down