Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 53 additions & 20 deletions astrbot/core/agent/runners/tool_loop_agent_runner.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

import asyncio
import copy
import sys
Expand Down Expand Up @@ -216,7 +217,7 @@ async def reset(
enforce_max_turns: int = -1,
# llm compressor
llm_compress_instruction: str | None = None,
llm_compress_keep_recent_ratio: float = 0.15,
llm_compress_keep_recent: int = 0,
Comment thread
bytecategory marked this conversation as resolved.
llm_compress_provider: Provider | None = None,
# truncate by turns compressor
truncate_turns: int = 1,
Expand All @@ -233,7 +234,7 @@ async def reset(
self.streaming = streaming
self.enforce_max_turns = enforce_max_turns
self.llm_compress_instruction = llm_compress_instruction
self.llm_compress_keep_recent_ratio = llm_compress_keep_recent_ratio
self.llm_compress_keep_recent = llm_compress_keep_recent
self.llm_compress_provider = llm_compress_provider
self.truncate_turns = truncate_turns
self.custom_token_counter = custom_token_counter
Expand All @@ -248,7 +249,7 @@ async def reset(
enforce_max_turns=self.enforce_max_turns,
truncate_turns=self.truncate_turns,
llm_compress_instruction=self.llm_compress_instruction,
llm_compress_keep_recent_ratio=self.llm_compress_keep_recent_ratio,
llm_compress_keep_recent=self.llm_compress_keep_recent,
llm_compress_provider=self.llm_compress_provider,
custom_token_counter=self.custom_token_counter,
custom_compressor=self.custom_compressor,
Expand Down Expand Up @@ -458,8 +459,11 @@ async def _iter_llm_responses(
self, *, include_model: bool = True
) -> T.AsyncGenerator[LLMResponse, None]:
"""Yields chunks *and* a final LLMResponse."""
messages_for_provider = getattr(
self, "_provider_messages", self.run_context.messages
)
payload = {
"contexts": self._sanitize_contexts_for_provider(self.run_context.messages),
"contexts": self._sanitize_contexts_for_provider(messages_for_provider),
"func_tool": self._func_tool_for_provider(),
"session_id": self.req.session_id,
"extra_user_content_parts": self.req.extra_user_content_parts, # list[ContentPart]
Expand Down Expand Up @@ -580,10 +584,7 @@ def _sanitize_contexts_for_provider(
self,
contexts: list[Message] | list[dict[str, T.Any]],
) -> list[Message] | list[dict[str, T.Any]]:
modalities = self.provider.provider_config.get("modalities", None)
if (
not modalities
): # Unconfigured (None or empty list) defaults to support all modalities
if not self._should_fix_modalities_for_provider():
return contexts
sanitized_contexts, stats = sanitize_contexts_by_modalities(
contexts,
Expand All @@ -592,6 +593,12 @@ def _sanitize_contexts_for_provider(
log_context_sanitize_stats(stats)
return sanitized_contexts

def _should_fix_modalities_for_provider(self) -> bool:
modalities = self.provider.provider_config.get("modalities", None)
return (
Comment thread
bytecategory marked this conversation as resolved.
isinstance(modalities, list) and modalities
) # Empty list is treated as unconfigured
Comment thread
bytecategory marked this conversation as resolved.

def _func_tool_for_provider(self) -> ToolSet | None:
if not self.req.func_tool:
return None
Expand All @@ -604,14 +611,11 @@ def _func_tool_for_provider(self) -> ToolSet | None:
return None
return self.req.func_tool

def _simple_print_message_role(self, tag: str, messages: list):
roles = [m.role for m in messages]
n = len(roles)
if n > 10:
summary = ",".join(roles[:4]) + ",...," + ",".join(roles[-4:])
else:
summary = ",".join(roles)
logger.debug(f"{tag} messages -> [{n}] {summary}")
def _simple_print_message_role(self, tag: str = ""):
roles = []
for message in self.run_context.messages:
roles.append(message.role)
logger.debug(f"{tag} RunCtx.messages -> [{len(roles)}] {','.join(roles)}")

def follow_up(
self,
Expand Down Expand Up @@ -705,13 +709,16 @@ async def step(self):
self._transition_state(AgentState.RUNNING)
llm_resp_result = None

# Process request-time context before sending it to the provider.
# Process request-time context on a copy so the runner's canonical
# messages are never mutated. The processed result is only used for this
# provider call. Persistent compaction is owned by the conversation /
# memory layer.
token_usage = self.req.conversation.token_usage if self.req.conversation else 0
self._simple_print_message_role("[BefCompact]", self.run_context.messages)
self.run_context.messages = await self.request_context_manager.process(
self._simple_print_message_role("[BefCompact]")
self._provider_messages = await self.request_context_manager.process(
self.run_context.messages, trusted_token_usage=token_usage
)
self._simple_print_message_role("[AftCompact]", self.run_context.messages)
self._simple_print_message_role("[AftCompact]")

async for llm_response in self._iter_llm_responses_with_fallback():
if llm_response.is_chunk:
Expand Down Expand Up @@ -893,6 +900,25 @@ async def step(self):
parts.append(TextPart(text=llm_resp.completion_text))
if len(parts) == 0:
parts = None

# 过滤掉无效的 tool calls,确保 assistant 消息不包含无 id/name 的条目
if llm_resp.tools_call_name and llm_resp.tools_call_ids:
valid_indices = [
i for i, (name, tid) in enumerate(
zip(llm_resp.tools_call_name, llm_resp.tools_call_ids)
)
if name and tid
]
if len(valid_indices) < len(llm_resp.tools_call_name):
llm_resp.tools_call_name = [llm_resp.tools_call_name[i] for i in valid_indices]
llm_resp.tools_call_args = [llm_resp.tools_call_args[i] for i in valid_indices]
Comment thread
bytecategory marked this conversation as resolved.
llm_resp.tools_call_ids = [llm_resp.tools_call_ids[i] for i in valid_indices]
Comment thread
bytecategory marked this conversation as resolved.

# 如果过滤后没有有效的 tool calls,跳过构建 tool_calls_result
if not llm_resp.tools_call_name:
await self._complete_with_assistant_response(llm_resp)
return

tool_calls_result = ToolCallsResult(
tool_calls_info=AssistantMessageSegment(
tool_calls=llm_resp.to_openai_to_calls_model(),
Expand Down Expand Up @@ -997,6 +1023,13 @@ def _append_tool_call_result(tool_call_id: str, content: str) -> None:
llm_response.tools_call_args,
llm_response.tools_call_ids,
):
# 跳过无效的 tool call(name 或 id 为空)
if not func_tool_name or not func_tool_id:
logger.warning(
f"Skipping invalid tool call with name={func_tool_name!r}, id={func_tool_id!r}"
)
continue

tool_result_blocks_start = len(tool_call_result_blocks)
tool_call_streak = self._track_tool_call_streak(func_tool_name)
yield _HandleFunctionToolsResult.from_message_chain(
Expand Down
2 changes: 1 addition & 1 deletion astrbot/core/astr_agent_tool_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ async def _execute_handoff(
continue

prov_settings: dict = ctx.get_config(umo=umo).get("provider_settings", {})
agent_max_step = int(prov_settings.get("max_agent_step", 30))
agent_max_step = int(prov_settings.get("max_agent_step", 114514))
stream = prov_settings.get("streaming_response", False)
llm_resp = await ctx.tool_loop_agent(
event=event,
Expand Down
30 changes: 25 additions & 5 deletions astrbot/core/computer/booters/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from python_ripgrep import search

from astrbot.api import logger
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.computer.file_read_utils import (
detect_text_encoding,
read_local_text_range_sync,
Expand All @@ -22,10 +23,7 @@
from .base import ComputerBooter
from .shipyard_search_file_util import _truncate_long_lines

_BLOCKED_COMMAND_PATTERNS = [
" rm -rf ",
" rm -fr ",
" rm -r ",
DEFAULT_BLOCKED_COMMAND_PATTERNS = [
" mkfs",
" dd if=",
" shutdown",
Expand All @@ -39,9 +37,31 @@
]


def _get_blocked_command_patterns() -> list[str]:
"""Return user-configured blocked command patterns.

The configuration is read on each shell execution so dashboard changes take
effect without recreating the local booter. If the config cannot be loaded
or the value has an unexpected type, fall back to the built-in defaults.
"""
try:
computer_config = AstrBotConfig().get("computer", {})
patterns = computer_config.get("blocked_command_patterns")
except Exception as e:
logger.warning(
f"Failed to load computer.blocked_command_patterns, using defaults: {e}"
)
return DEFAULT_BLOCKED_COMMAND_PATTERNS

if not isinstance(patterns, list):
return DEFAULT_BLOCKED_COMMAND_PATTERNS

return [str(pattern).lower() for pattern in patterns if str(pattern)]


def _is_safe_command(command: str) -> bool:
cmd = f" {command.strip().lower()} "
return not any(pat in cmd for pat in _BLOCKED_COMMAND_PATTERNS)
return not any(pat in cmd for pat in _get_blocked_command_patterns())


def _decode_bytes_with_fallback(
Expand Down
Loading