diff --git a/aworld/core/common.py b/aworld/core/common.py index 6c35f8c79..5b442cd4c 100644 --- a/aworld/core/common.py +++ b/aworld/core/common.py @@ -6,7 +6,6 @@ from typing import Dict, Any, Optional, Union, List, Literal from enum import Enum - from aworld.config import ConfigDict Config = Union[Dict[str, Any], ConfigDict, BaseModel] @@ -96,15 +95,18 @@ class TaskItem(BaseModel): params: Optional[Dict[str, Any]] = {} policy_info: Optional[Any] = None + class CallbackItem(BaseModel): data: Any node_id: str = None actions: List[ActionModel] = [] + class CallbackActionType(str, Enum): BYPASS = "bypass" OVERRIDE = "override" + class CallbackResult(BaseModel): success: bool = False result_data: Any = None @@ -124,7 +126,7 @@ class StreamingMode(enum.Enum): ALL = 'all' -class TaskStatusValue: +class TaskStatus(str): """Task status constants.""" INIT = 'init' RUNNING = 'running' @@ -133,5 +135,10 @@ class TaskStatusValue: CANCELLED = 'cancelled' INTERRUPTED = 'interrupted' TIMEOUT = 'timeout' + DISABLED = 'disabled' + -TaskStatus = Literal['init', 'running', 'success', 'failed', 'cancelled', 'interrupted', 'timeout'] +class TaskTypeValue: + """Task type constants.""" + INSTANT = 'instant' + SCHEDULED = 'scheduled' diff --git a/aworld/core/context/amni/__init__.py b/aworld/core/context/amni/__init__.py index 9bb6614e1..f00f537b0 100644 --- a/aworld/core/context/amni/__init__.py +++ b/aworld/core/context/amni/__init__.py @@ -9,7 +9,6 @@ from aworld import trace from aworld.config import AgentConfig, AgentMemoryConfig -from aworld.core.common import TaskStatus # lazy import from aworld.core.context.base import Context from aworld.dataset.types import TrajectoryItem diff --git a/aworld/core/context/amni/contexts.py b/aworld/core/context/amni/contexts.py index ebd4c5320..d70b03429 100644 --- a/aworld/core/context/amni/contexts.py +++ b/aworld/core/context/amni/contexts.py @@ -20,7 +20,7 @@ from .state import TaskInput, Summary from .utils import jsonplus from .worksapces 
import workspace_repo -from ...task import TaskStatusValue +from ...task import TaskStatus class ContextManager(BaseModel): diff --git a/aworld/core/context/amni/state/task_state.py b/aworld/core/context/amni/state/task_state.py index 6166d6175..ec2c03162 100644 --- a/aworld/core/context/amni/state/task_state.py +++ b/aworld/core/context/amni/state/task_state.py @@ -3,7 +3,7 @@ from pydantic import BaseModel, Field from pydantic import field_validator -from aworld.core.task import TaskStatus +from aworld.core.common import TaskStatus from .agent_state import ApplicationAgentState from .common import WorkingState, TaskInput, TaskOutput from ..utils.modelplus import from_dict_to_memory_message diff --git a/aworld/core/context/base.py b/aworld/core/context/base.py index 1d9257ebb..79deacb2a 100644 --- a/aworld/core/context/base.py +++ b/aworld/core/context/base.py @@ -15,7 +15,7 @@ from aworld.utils.common import nest_dict_counter if TYPE_CHECKING: - from aworld.core.task import Task, TaskResponse, TaskStatus, TaskStatusValue + from aworld.core.task import Task, TaskResponse, TaskStatus from aworld.events.manager import EventManager from aworld.core.agent import BaseAgent from aworld.core.context.amni import AgentContextConfig @@ -783,8 +783,8 @@ async def snapshot(self): return checkpoint async def get_task_status(self): - from aworld.core.common import TaskStatusValue - return TaskStatusValue.SUCCESS + from aworld.core.common import TaskStatus + return TaskStatus.SUCCESS async def update_task_status(self, task_id: str, status: 'TaskStatus'): pass diff --git a/aworld/core/event/message_future.py b/aworld/core/event/message_future.py index 730562412..44ff0ccee 100644 --- a/aworld/core/event/message_future.py +++ b/aworld/core/event/message_future.py @@ -84,10 +84,10 @@ async def wait(self, timeout: float = 30.0, context: 'Context' = None): from aworld.logs.util import logger logger.info(f"Waiting for message {self.msg_id}") if context: - from aworld.core.common import 
TaskStatusValue + from aworld.core.common import TaskStatus task_status = await context.get_task_status() - if (task_status == TaskStatusValue.CANCELLED - or task_status == TaskStatusValue.INTERRUPTED): + if (task_status == TaskStatus.CANCELLED + or task_status == TaskStatus.INTERRUPTED): self.set_empty_result(msg=f"Task {task_status.lower()}: message not sent") logger.info(f"Task {task_status.lower()}: message not sent") return self.result() diff --git a/aworld/core/task.py b/aworld/core/task.py index 27ad50474..c0aa3a1bb 100644 --- a/aworld/core/task.py +++ b/aworld/core/task.py @@ -1,18 +1,16 @@ # coding: utf-8 # Copyright (c) 2025 inclusionAI. import abc -import asyncio -import enum import uuid from dataclasses import dataclass, field -from typing import Any, Union, List, Dict, Callable, Optional, Literal, TYPE_CHECKING, AsyncGenerator +from typing import Any, Union, List, Dict, Callable, Optional, AsyncGenerator from aworld.core.event.base import Message from aworld.utils.serialized_util import to_serializable from aworld.agents.llm_agent import Agent from aworld.core.agent.swarm import Swarm -from aworld.core.common import Config, Observation, StreamingMode, TaskStatus, TaskStatusValue +from aworld.core.common import Config, Observation, StreamingMode, TaskStatus from aworld.core.context.base import Context from aworld.core.tool.base import Tool, AsyncTool from aworld.output.outputs import Outputs, DefaultOutputs @@ -58,7 +56,7 @@ class Task: max_retry_count: int = field(default=0) timeout: int = field(default=0) observation: Optional[Observation] = field(default=None) - task_status: TaskStatus = field(default=TaskStatusValue.INIT) + task_status: TaskStatus = field(default=TaskStatus.INIT) # streaming support streaming_mode: StreamingMode = field(default=None) @@ -109,9 +107,9 @@ class TaskResponse: time_cost: float | None = field(default=0.0) success: bool = field(default=False) msg: str | None = field(default=None) - trajectory: List[Dict[str, Any]]= 
field(default_factory=list) + trajectory: List[Dict[str, Any]] = field(default_factory=list) # task final status, e.g. success/failed/cancelled - status: TaskStatus | None = field(default=TaskStatusValue.SUCCESS) + status: str | None = field(default=TaskStatus.SUCCESS) def to_dict(self) -> Dict[str, Any]: return { diff --git a/aworld/evaluations/scorers/output_validators.py b/aworld/evaluations/scorers/output_validators.py index c76344789..4e56adb58 100644 --- a/aworld/evaluations/scorers/output_validators.py +++ b/aworld/evaluations/scorers/output_validators.py @@ -343,13 +343,6 @@ def build_judge_data(self, index: int, input: Any, output: Any) -> str: @scorer_register( MetricNames.OUTPUT_RELEVANCE, - model_config=ModelConfig( - llm_provider=os.getenv("VALIDATE_LLM_PROVIDER", os.getenv("LLM_PROVIDER", "openai")), - llm_model_name=os.getenv("VALIDATE_LLM_MODEL_NAME", os.getenv("LLM_MODEL_NAME")), - llm_temperature=float(os.getenv("VALIDATE_LLM_TEMPERATURE", os.getenv("LLM_TEMPERATURE", "0.7"))), - llm_base_url=os.getenv("VALIDATE_LLM_BASE_URL", os.getenv("LLM_BASE_URL")), - llm_api_key=os.getenv("VALIDATE_LLM_API_KEY", os.getenv("LLM_API_KEY")), - ) ) class OutputRelevanceScorer(OutputLlmScore): """Verify the correlation between the answer and the question @@ -369,13 +362,6 @@ def _build_judge_system_prompt(self) -> str: @scorer_register( MetricNames.OUTPUT_COMPLETENESS, - model_config=ModelConfig( - llm_provider=os.getenv("VALIDATE_LLM_PROVIDER", os.getenv("LLM_PROVIDER", "openai")), - llm_model_name=os.getenv("VALIDATE_LLM_MODEL_NAME", os.getenv("LLM_MODEL_NAME")), - llm_temperature=float(os.getenv("VALIDATE_LLM_TEMPERATURE", os.getenv("LLM_TEMPERATURE", "0.7"))), - llm_base_url=os.getenv("VALIDATE_LLM_BASE_URL", os.getenv("LLM_BASE_URL")), - llm_api_key=os.getenv("VALIDATE_LLM_API_KEY", os.getenv("LLM_API_KEY")), - ) ) class OutputCompletenessScorer(OutputLlmScore): """Verify the completeness of the answer @@ -406,13 +392,13 @@ def build_judge_data(self, index: 
int, input: EvalDataCase, output: Any) -> str: @scorer_register( MetricNames.OUTPUT_QUALITY, - model_config=ModelConfig( - llm_provider=os.getenv("VALIDATE_LLM_PROVIDER", os.getenv("LLM_PROVIDER", "openai")), - llm_model_name=os.getenv("VALIDATE_LLM_MODEL_NAME", os.getenv("LLM_MODEL_NAME")), - llm_temperature=float(os.getenv("VALIDATE_LLM_TEMPERATURE", os.getenv("LLM_TEMPERATURE", "0.7"))), - llm_base_url=os.getenv("VALIDATE_LLM_BASE_URL", os.getenv("LLM_BASE_URL")), - llm_api_key=os.getenv("VALIDATE_LLM_API_KEY", os.getenv("LLM_API_KEY")), - ) + # model_config=ModelConfig( + # llm_provider=os.getenv("VALIDATE_LLM_PROVIDER", os.getenv("LLM_PROVIDER", "openai")), + # llm_model_name=os.getenv("VALIDATE_LLM_MODEL_NAME", os.getenv("LLM_MODEL_NAME")), + # llm_temperature=float(os.getenv("VALIDATE_LLM_TEMPERATURE", os.getenv("LLM_TEMPERATURE", "0.7"))), + # llm_base_url=os.getenv("VALIDATE_LLM_BASE_URL", os.getenv("LLM_BASE_URL")), + # llm_api_key=os.getenv("VALIDATE_LLM_API_KEY", os.getenv("LLM_API_KEY")), + # ) ) class OutputQualityScorer(OutputLlmScore): """Comprehensive evaluation of answer quality diff --git a/aworld/events/util.py b/aworld/events/util.py index 42c6f18d7..fd04d3b77 100644 --- a/aworld/events/util.py +++ b/aworld/events/util.py @@ -3,7 +3,7 @@ from typing import Callable, Any, List import asyncio -from aworld.core.common import TaskStatusValue +from aworld.core.common import TaskStatus from aworld.core.context.base import Context from aworld.events import eventbus from aworld.core.event.base import Message, Constants @@ -56,9 +56,9 @@ async def send_message(msg: Message): """ context = msg.context if context: - from aworld.core.common import TaskStatusValue + from aworld.core.common import TaskStatus task_status = await context.get_task_status() - if task_status == TaskStatusValue.CANCELLED or task_status == TaskStatusValue.INTERRUPTED: + if task_status == TaskStatus.CANCELLED or task_status == TaskStatus.INTERRUPTED: await _send_finish_message(msg, 
task_status) return await _send_message(msg) @@ -78,9 +78,9 @@ async def send_and_wait_message(msg: Message) -> List['HandleResult'] | None: """ context = msg.context if context: - from aworld.core.common import TaskStatusValue + from aworld.core.common import TaskStatus task_status = await context.get_task_status() - if task_status == TaskStatusValue.CANCELLED or task_status == TaskStatusValue.INTERRUPTED: + if task_status == TaskStatus.CANCELLED or task_status == TaskStatus.INTERRUPTED: await _send_finish_message(msg, task_status) return None await _send_message(msg) @@ -184,9 +184,9 @@ async def send_message_with_future(msg: Message) -> MessageFuture: if context: - from aworld.core.common import TaskStatusValue + from aworld.core.common import TaskStatus task_status = await context.get_task_status() - if task_status == TaskStatusValue.CANCELLED or task_status == TaskStatusValue.INTERRUPTED: + if task_status == TaskStatus.CANCELLED or task_status == TaskStatus.INTERRUPTED: await _send_finish_message(msg, task_status) # Task cancelled or interrupted, return a completed Future with empty result dummy_msg_id = f"cancelled_{msg.id}" @@ -199,7 +199,7 @@ async def send_message_with_future(msg: Message) -> MessageFuture: future = MessageFuture(msg_id) return future -async def _send_finish_message(msg: Message, status: str = TaskStatusValue.SUCCESS): +async def _send_finish_message(msg: Message, status: str = TaskStatus.SUCCESS): context = msg.context await _send_message(Message(payload=f"Task {status.lower()}",session_id=context.session_id, category=Constants.TASK, headers={"context": context})) diff --git a/aworld/memory/utils.py b/aworld/memory/utils.py deleted file mode 100644 index 1b094df73..000000000 --- a/aworld/memory/utils.py +++ /dev/null @@ -1,16 +0,0 @@ - - -from aworld.core.memory import MemoryItem - - -def build_history_context(history_messages: list[MemoryItem]) -> str: - """ - Build history context from history messages. 
- """ - history_context = "" - for message in history_messages: - if message.role == "user": - history_context += f"User: {message.content}\n" - else: - history_context += f"Agent: {message.content}\n" - return history_context \ No newline at end of file diff --git a/aworld/runners/event_runner.py b/aworld/runners/event_runner.py index daab94bcc..29e17dee0 100644 --- a/aworld/runners/event_runner.py +++ b/aworld/runners/event_runner.py @@ -15,7 +15,7 @@ from aworld.dataset.trajectory_storage import get_storage_instance from aworld.core.event.base import Message, Constants, TopicType, ToolMessage, AgentMessage from aworld.core.exceptions import AWorldRuntimeException -from aworld.core.task import Task, TaskResponse, TaskStatusValue +from aworld.core.task import Task, TaskResponse, TaskStatus from aworld.dataset.trajectory_dataset import TrajectoryDataset from aworld.events.manager import EventManager from aworld.logs.util import logger, trajectory_logger @@ -349,7 +349,7 @@ async def _do_run(self): time_cost=( time.time() - start), usage=self.context.token_usage, - status=TaskStatusValue.SUCCESS if not msg else TaskStatusValue.FAILED) + status=TaskStatus.SUCCESS if not msg else TaskStatus.FAILED) break logger.debug(f"{task_flag} task {self.task.id} next message snap") # consume message @@ -424,7 +424,7 @@ def _response(self): self._task_response = TaskResponse(id=self.context.task_id if self.context else "", success=False, msg="Task return None.", - status=TaskStatusValue.FAILED) + status=TaskStatus.FAILED) if self.context.get_task().conf and self.context.get_task().conf.resp_carry_raw_llm_resp == True: self._task_response.raw_llm_resp = self.context.context_info.get('llm_output') self._task_response.trace_id = get_trace_id() @@ -466,14 +466,14 @@ async def should_stop_task(self, message: Message): time_cost=(time.time() - self.start_time), usage=self.context.token_usage, msg=f'Task timeout after {time_cost} seconds.', - status=TaskStatusValue.TIMEOUT + 
status=TaskStatus.TIMEOUT ) - await self.context.update_task_status(self.task.id, TaskStatusValue.TIMEOUT) + await self.context.update_task_status(self.task.id, TaskStatus.TIMEOUT) return True # Check Task status from context task_status = await self.context.get_task_status() - if task_status == TaskStatusValue.INTERRUPTED or task_status == TaskStatusValue.CANCELLED: + if task_status == TaskStatus.INTERRUPTED or task_status == TaskStatus.CANCELLED: logger.warn(f"{task_flag} task {self.task.id} is {task_status}.") self._task_response = TaskResponse( answer='', diff --git a/aworld/runners/handler/background_task.py b/aworld/runners/handler/background_task.py index 6ded0004c..839d3fe4f 100644 --- a/aworld/runners/handler/background_task.py +++ b/aworld/runners/handler/background_task.py @@ -1,25 +1,18 @@ # coding: utf-8 # Copyright (c) 2025 inclusionAI. import abc -import asyncio -import time from typing import AsyncGenerator, TYPE_CHECKING, Tuple -from env_channel import EnvChannelMessage - from aworld.core.common import TaskItem, Observation -from aworld.core.context.amni import get_context_manager, ContextManager, ApplicationContext, AmniContext from aworld.core.context.base import Context from aworld.core.event.base import Message, Constants, TopicType, BackgroundTaskMessage, AgentMessage -from aworld.core.task import TaskResponse, TaskStatusValue, Task, Runner +from aworld.core.task import TaskResponse, TaskStatus, Task, Runner from aworld.events.util import send_message from aworld.logs.util import logger from aworld.memory.main import MemoryFactory from aworld.memory.models import MemoryHumanMessage, MessageMetadata -from aworld.runner import Runners from aworld.runners import HandlerFactory from aworld.runners.handler.base import DefaultHandler -from aworld.runners.hook.hooks import HookPoint if TYPE_CHECKING: from aworld.runners.event_runner import TaskEventRunner @@ -166,11 +159,11 @@ async def _merge_by_topic(self, message: Message): content = data.content 
elif isinstance(data, Observation): content = data.content - elif isinstance(data, EnvChannelMessage): - data = data.message - if not agent_id: - agent_id = data.get('env_content', {}).get('agent_id') - content = data + # elif isinstance(data, EnvChannelMessage): + # data = data.message + # if not agent_id: + # agent_id = data.get('env_content', {}).get('agent_id') + # content = data elif isinstance(data, dict): if not agent_id: agent_id = data.get('env_content', {}).get('agent_id') diff --git a/aworld/runners/handler/task.py b/aworld/runners/handler/task.py index 5e99f949b..42bd12572 100644 --- a/aworld/runners/handler/task.py +++ b/aworld/runners/handler/task.py @@ -8,7 +8,7 @@ from aworld.core.common import TaskItem from aworld.core.event.base import Message, Constants, TopicType -from aworld.core.task import TaskResponse, TaskStatusValue +from aworld.core.task import TaskResponse, TaskStatus from aworld.core.tool.base import Tool, AsyncTool from aworld.logs.util import logger, trajectory_logger from aworld.output import Output @@ -88,7 +88,7 @@ async def _do_handle(self, message: Message) -> AsyncGenerator[Message, None]: id=self.runner.task.id, time_cost=(time.time() - self.runner.start_time), usage=self.runner.context.token_usage, - status=TaskStatusValue.FAILED) + status=TaskStatus.FAILED) await self.runner.stop() yield Message(payload=self.runner._task_response, session_id=message.session_id, @@ -149,7 +149,7 @@ async def _do_handle(self, message: Message) -> AsyncGenerator[Message, None]: time_cost=(time.time() - self.runner.start_time), usage=self.runner.context.token_usage, msg=f'cancellation message received: {task_item.msg}', - status=TaskStatusValue.CANCELLED) + status=TaskStatus.CANCELLED) await self.runner.stop() yield Message(payload=self.runner._task_response, session_id=message.session_id, headers=message.headers, topic=TopicType.TASK_RESPONSE) @@ -165,7 +165,7 @@ async def _do_handle(self, message: Message) -> AsyncGenerator[Message, None]: 
time_cost=(time.time() - self.runner.start_time), usage=self.runner.context.token_usage, msg=f'interruption message received: {task_item.msg}', - status=TaskStatusValue.INTERRUPTED) + status=TaskStatus.INTERRUPTED) await self.runner.stop() yield Message(payload=self.runner._task_response, session_id=message.session_id, headers=message.headers, topic=TopicType.TASK_RESPONSE) diff --git a/aworld/runners/hook/agent_hooks.py b/aworld/runners/hook/agent_hooks.py index fc695cffa..21eef232e 100644 --- a/aworld/runners/hook/agent_hooks.py +++ b/aworld/runners/hook/agent_hooks.py @@ -1,48 +1,38 @@ # coding: utf-8 # Copyright (c) 2025 inclusionAI. -import abc - from aworld.core.context.base import Context from aworld.core.event.base import Message from aworld.runners.hook.hook_factory import HookFactory -from aworld.runners.hook.hooks import PostLLMCallHook, PreLLMCallHook +from aworld.runners.hook.hooks import OnStartLLMCallHook, OnFinishedLLMCallHook from aworld.utils.common import convert_to_snake @HookFactory.register(name="PreLLMCallContextProcessHook", desc="PreLLMCallContextProcessHook") -class PreLLMCallContextProcessHook(PreLLMCallHook): +class PreLLMCallContextProcessHook(OnStartLLMCallHook): """Process in the hook point of the pre_llm_call.""" - __metaclass__ = abc.ABCMeta def name(self): return convert_to_snake("PreLLMCallContextProcessHook") - async def exec(self, message: Message, context: Context = None) -> Message: - # and do something - pass - @HookFactory.register(name="PostLLMCallContextProcessHook", desc="PostLLMCallContextProcessHook") -class PostLLMCallContextProcessHook(PostLLMCallHook): +class PostLLMCallContextProcessHook(OnFinishedLLMCallHook): """Process in the hook point of the post_llm_call.""" - __metaclass__ = abc.ABCMeta def name(self): return convert_to_snake("PostLLMCallContextProcessHook") - async def exec(self, message: Message, context: Context = None) -> Message: - # get context - pass - 
@HookFactory.register(name="PostLLMTrajectoryHook", desc="PostLLMTrajectoryHook") -class PostLLMTrajectoryHook(PostLLMCallHook): +class PostLLMTrajectoryHook(OnFinishedLLMCallHook): """Update trajectory after llm call.""" + def name(self): return convert_to_snake("PostLLMTrajectoryHook") + async def exec(self, message: Message, context: Context = None) -> Message: # get context agent_message = message.headers.get("agent_message") diff --git a/aworld/runners/hook/hook_factory.py b/aworld/runners/hook/hook_factory.py index b0f7b94a3..c3579b215 100644 --- a/aworld/runners/hook/hook_factory.py +++ b/aworld/runners/hook/hook_factory.py @@ -5,7 +5,7 @@ from aworld.core.factory import Factory from aworld.logs.util import logger -from aworld.runners.hook.hooks import Hook, StartHook, HookPoint +from aworld.runners.hook.hooks import Hook, HookPoint class HookManager(Factory): diff --git a/aworld/runners/hook/hooks.py b/aworld/runners/hook/hooks.py index bb4305c9c..b9a961c32 100644 --- a/aworld/runners/hook/hooks.py +++ b/aworld/runners/hook/hooks.py @@ -4,20 +4,29 @@ from aworld.core.context.base import Context from aworld.core.event.base import Message -from aworld.models.model_response import ModelResponse class HookPoint: START = "start" FINISHED = "finished" ERROR = "error" - PRE_LLM_CALL = "pre_llm_call" - POST_LLM_CALL = "post_llm_call" OUTPUT_PROCESS = "output_process" - PRE_TOOL_CALL = "pre_tool_call" - POST_TOOL_CALL = "post_tool_call" - PRE_TASK_CALL = "pre_task_call" - POST_TASK_CALL = "post_task_call" + ON_START_LLM_CALL = "on_start_llm_call" + ON_FINISHED_LLM_CALL = "on_finished_llm_call" + ON_LLM_CALL = "on_llm_call" + ON_SUCCESS_LLM_CALL = "on_success_llm_call" + ON_ERROR_LLM_CALL = "on_error_llm_call" + ON_START_TOOL_CALL = "on_start_tool_call" + ON_TOOL_CALL = "on_tool_call" + ON_FINISHED_TOOL_CALL = "on_finished_tool_call" + ON_SUCCESS_TOOL_CALL = "on_success_tool_call" + ON_ERROR_TOOL_CALL = "on_error_tool_call" + ON_RUN_TASK = "on_run_task" + 
ON_SUCCESS_TASK = "on_success_task" + ON_ERROR_TASK = "on_error_task" + ON_START_TASK = "on_start_task" + ON_FINISHED_TASK = "on_finished_task" + class Hook: """Runner hook.""" @@ -27,14 +36,17 @@ class Hook: def point(self): """Hook point.""" - @abc.abstractmethod + def name(self): + """Hook name.""" + return self.__class__.__name__ + async def exec(self, message: Message, context: Context = None) -> Message: """Execute hook function.""" + pass class StartHook(Hook): """Process in the hook point of the start.""" - __metaclass__ = abc.ABCMeta def point(self): return HookPoint.START @@ -42,7 +54,6 @@ def point(self): class FinishedHook(Hook): """Process in the hook point of the finished.""" - __metaclass__ = abc.ABCMeta def point(self): return HookPoint.FINISHED @@ -50,35 +61,38 @@ def point(self): class ErrorHook(Hook): """Process in the hook point of the error.""" - __metaclass__ = abc.ABCMeta def point(self): return HookPoint.ERROR -class PreLLMCallHook(Hook): - """Process in the hook point of the pre_llm_call.""" - __metaclass__ = abc.ABCMeta +class OnStartLLMCallHook(Hook): def point(self): - return HookPoint.PRE_LLM_CALL - -class PostLLMCallHook(Hook): - """Process in the hook point of the post_llm_call.""" - __metaclass__ = abc.ABCMeta + return HookPoint.ON_START_LLM_CALL + +class OnFinishedLLMCallHook(Hook): def point(self): - return HookPoint.POST_LLM_CALL + return HookPoint.ON_FINISHED_LLM_CALL -class PostToolCallHook(Hook): - """Process in the hook point of the post_tool_call.""" - __metaclass__ = abc.ABCMeta +class OnLLMCallHook(Hook): def point(self): - return HookPoint.POST_TOOL_CALL + return HookPoint.ON_LLM_CALL + + +class OnSuccessLLMCallHook(Hook): + def point(self): + return HookPoint.ON_SUCCESS_LLM_CALL + + +class OnErrorLLMCallHook(Hook): + def point(self): + return HookPoint.ON_ERROR_LLM_CALL + class OutputProcessHook(Hook): """Output process hook for processing output data for display.""" - __metaclass__ = abc.ABCMeta def point(self): return 
HookPoint.OUTPUT_PROCESS @@ -93,34 +107,3 @@ def process_output_content(self, content: str) -> str: processed content """ return content - - -class PreToolCallHook(Hook): - """Process in the hook point of the pre_tool_call.""" - __metaclass__ = abc.ABCMeta - - def point(self): - return HookPoint.PRE_TOOL_CALL - - -class PostToolCallHook(Hook): - """Process in the hook point of the post_tool_call.""" - __metaclass__ = abc.ABCMeta - - def point(self): - return HookPoint.POST_TOOL_CALL - -class PreTaskCallHook(Hook): - """Process in the hook point of the post_task_call.""" - __metaclass__ = abc.ABCMeta - - def point(self): - return HookPoint.PRE_TASK_CALL - -class PostTaskCallHook(Hook): - """Process in the hook point of the post_task_call.""" - __metaclass__ = abc.ABCMeta - - def point(self): - return HookPoint.POST_TASK_CALL - diff --git a/aworld/runners/hook/task_hooks.py b/aworld/runners/hook/task_hooks.py new file mode 100644 index 000000000..1060f720a --- /dev/null +++ b/aworld/runners/hook/task_hooks.py @@ -0,0 +1,28 @@ +# coding: utf-8 +# Copyright (c) 2025 inclusionAI. +from aworld.runners.hook.hooks import Hook, HookPoint + + +class OnRunHook(Hook): + def point(self): + return HookPoint.ON_RUN_TASK + + +class OnSuccessHook(Hook): + def point(self): + return HookPoint.ON_SUCCESS_TASK + + +class OnErrorHook(Hook): + def point(self): + return HookPoint.ON_ERROR_TASK + + +class OnStartHook(Hook): + def point(self): + return HookPoint.ON_START_TASK + + +class OnFinishHook(Hook): + def point(self): + return HookPoint.ON_FINISHED_TASK diff --git a/aworld/runners/hook/tool_hooks.py b/aworld/runners/hook/tool_hooks.py new file mode 100644 index 000000000..310b3e476 --- /dev/null +++ b/aworld/runners/hook/tool_hooks.py @@ -0,0 +1,28 @@ +# coding: utf-8 +# Copyright (c) 2025 inclusionAI. 
+from aworld.runners.hook.hooks import Hook, HookPoint + + +class OnStartToolCallHook(Hook): + def point(self): + return HookPoint.ON_START_TOOL_CALL + + +class OnFinishedToolCallHook(Hook): + def point(self): + return HookPoint.ON_FINISHED_TOOL_CALL + + +class OnToolCallHook(Hook): + def point(self): + return HookPoint.ON_TOOL_CALL + + +class OnSuccessToolCallHook(Hook): + def point(self): + return HookPoint.ON_SUCCESS_TOOL_CALL + + +class OnErrorToolCallHook(Hook): + def point(self): + return HookPoint.ON_ERROR_TOOL_CALL diff --git a/aworld/runners/task_manager.py b/aworld/runners/task_manager.py new file mode 100644 index 000000000..bfbe7f440 --- /dev/null +++ b/aworld/runners/task_manager.py @@ -0,0 +1,513 @@ +# coding: utf-8 +# Copyright (c) inclusionAI. +import time +from typing import List, Optional, Set + +from aworld.core.storage.base import Storage +from aworld.core.storage.data import Data +from aworld.core.storage.inmemory_store import InmemoryStorage +from aworld.core.task import Task +from aworld.logs.util import logger + +from aworld.core.common import TaskStatus + + +class TaskManager: + """High-level task store providing simple API for task persistence. 
+ + TaskManager acts as a facade over a Storage backend, providing: + - Simple CRUD operations + - Task querying and filtering + - Status-based retrieval + - Ready task detection + + Examples: + # Initialize with a storage backend (defaults to InmemoryStorage; None is rejected) + + manager = TaskManager(storage=InmemoryStorage()) + + # Add task + task = ScheduledTask(id="task1_id", name="My Task") + await manager.add_task(task) + + # Get task + task = await manager.get_task("task1_id") + + # List tasks + all_tasks = await manager.list() + pending_tasks = await manager.list(status=TaskStatus.INIT) + + # Get ready tasks + ready = await manager.get_ready() + """ + + def __init__(self, storage: Storage = InmemoryStorage()): + if storage is None: + raise ValueError("storage parameter is required") + + if not isinstance(storage, Storage): + raise TypeError(f"storage must be TaskStorage instance, got {type(storage)}") + + self.storage = storage + self._cache_completed: Set[str] = set() + + async def add_task(self, task: Task, overwrite: bool = True) -> bool: + """Add a task to the store. + + Args: + task: Task to add + overwrite: Whether to overwrite if task exists + + Returns: + bool: True if successful + + Raises: + ValueError: If task.id is empty + """ + if not task.id: + raise ValueError("Task ID cannot be empty") + + try: + success = await self.storage.create_data( + Data(value=task), + block_id=task.id, + overwrite=overwrite + ) + + if success: + logger.debug(f"Added task: {task.id}") + + return success + except Exception as e: + logger.error(f"Failed to add task {task.id}: {e}") + return False + + async def add_batch(self, tasks: List[Task], overwrite: bool = True) -> int: + """Add multiple tasks in batch. 
+ + Args: + tasks: List of tasks to add + overwrite: Whether to overwrite existing tasks + + Returns: + int: Number of successfully added tasks + """ + success_count = 0 + for task in tasks: + if await self.add_task(task, overwrite=overwrite): + success_count += 1 + + logger.debug(f"Added {success_count}/{len(tasks)} tasks") + return success_count + + async def get_task(self, task_id: str) -> Optional[Task]: + """Get a task by ID. + + Args: + task_id: Task ID + + Returns: + Task or None if not found + """ + try: + tasks = await self.storage.get_data_items(block_id=task_id) + return tasks[0].value if tasks else None + except Exception as e: + logger.error(f"Failed to get task {task_id}: {e}") + return None + + async def update_task(self, task: Task) -> bool: + """Update an existing task. + + Args: + task: Updated task + + Returns: + bool: True if successful + """ + try: + success = await self.storage.update_data( + Data(value=task), + block_id=task.id, + exists=True + ) + + if success: + logger.debug(f"Updated task: {task.id}") + return success + except Exception as e: + logger.error(f"Failed to update task {task.id}: {e}") + return False + + async def update_status( + self, + task_id: str, + status: str, + **fields + ) -> bool: + """Update task status and optional fields. + + Args: + task_id: Task ID + status: New status + **fields: Additional fields to update (e.g., started_at, completed_at) + + Returns: + bool: True if successful + """ + task = await self.get_task(task_id) + if not task: + logger.warning(f"Task {task_id} not found") + return False + + task.task_status = status + + for field, value in fields.items(): + if hasattr(task, field): + setattr(task, field, value) + + return await self.update_task(task) + + async def delete_task(self, task_id: str) -> bool: + """Delete a task. 
+ + Args: + task_id: Task ID + + Returns: + bool: True if successful + """ + try: + success = await self.storage.delete_data( + task_id, + block_id=task_id, + exists=False + ) + + if success: + logger.debug(f"Deleted task: {task_id}") + self._cache_completed.discard(task_id) + + return success + except Exception as e: + logger.error(f"Failed to delete task {task_id}: {e}") + return False + + async def exists(self, task_id: str) -> bool: + """Check if a task exists. + + Args: + task_id: Task ID + + Returns: + bool: True if task exists + """ + return await self.get_task(task_id) is not None + + async def list( + self, + status: Optional[str] = None, + limit: Optional[int] = None, + offset: int = 0 + ) -> List[Task]: + """List tasks with optional filtering. + + Args: + status: Filter by status (optional) + limit: Maximum number of tasks to return + offset: Number of tasks to skip + + Returns: + List of tasks + """ + try: + if status: + tasks = await self.get_tasks_by_status(status) + else: + tasks = await self.storage.get_data_items() + + # Apply offset and limit + tasks = tasks[offset:] if offset > 0 else tasks + tasks = tasks[:limit] if limit else tasks + return tasks + except Exception as e: + logger.error(f"Failed to list tasks: {e}") + return [] + + async def get_ready( + self, + current_time: Optional[float] = None, + limit: Optional[int] = None + ) -> List[Task]: + """Get tasks that are ready to execute. 
+ + Args: + current_time: Current timestamp (default: now) + limit: Maximum number of tasks + + Returns: + List of ready tasks sorted by priority and time + """ + current_time = current_time if current_time else time.time() + + try: + ready_tasks = await self.get_ready_tasks(current_time, limit) + return ready_tasks + except Exception as e: + logger.error(f"Failed to get ready tasks: {e}") + return [] + + async def get_pending(self, limit: Optional[int] = None) -> List[Task]: + """Get all pending (INIT status) tasks.""" + return await self.list(status=TaskStatus.INIT, limit=limit) + + async def get_running(self, limit: Optional[int] = None) -> List[Task]: + """Get all running tasks.""" + return await self.list(status=TaskStatus.RUNNING, limit=limit) + + async def get_completed(self, limit: Optional[int] = None) -> List[Task]: + """Get all completed (SUCCESS status) tasks.""" + return await self.list(status=TaskStatus.SUCCESS, limit=limit) + + async def get_failed(self, limit: Optional[int] = None) -> List[Task]: + """Get all failed tasks.""" + return await self.list(status=TaskStatus.FAILED, limit=limit) + + async def get_periodic(self) -> List[Task]: + """Get all periodic tasks (with cron expression).""" + try: + return await self.get_periodic_tasks() + except Exception as e: + logger.error(f"Failed to get periodic tasks: {e}") + return [] + + async def count(self, status: Optional[str] = None) -> int: + """ + Count tasks, optionally filtered by status. + + Args: + status: Filter by status (optional) + + Returns: + Number of tasks + """ + try: + if status: + return await self.count_by_status(status) + else: + return await self.storage.size() + except Exception as e: + logger.error(f"Failed to count tasks: {e}") + return 0 + + async def clear(self) -> bool: + """ + Clear all tasks (use with caution). 
+ + Returns: + bool: True if successful + """ + try: + await self.storage.delete_all() + self._cache_completed.clear() + logger.warning("Cleared all tasks from store") + return True + except Exception as e: + logger.error(f"Failed to clear tasks: {e}") + return False + + async def cleanup_completed( + self, + before_time: Optional[float] = None, + keep_periodic: bool = True + ) -> int: + """Clean up old completed tasks. + + Args: + before_time: Remove tasks completed before this time (default: 7 days ago) + keep_periodic: Keep periodic tasks even if completed + + Returns: + int: Number of tasks removed + """ + if before_time is None: + before_time = time.time() - (7 * 24 * 3600) + + try: + completed_tasks = await self.get_completed() + removed_count = 0 + + for task in completed_tasks: + # Check if task has completed_at attribute (for compatibility) + completed_at = getattr(task, 'completed_at', None) + is_periodic = getattr(task, 'is_periodic', False) + + # Determine if should remove + should_remove = False + if completed_at is not None and completed_at < before_time: + # Only remove if not keeping periodic tasks or task is not periodic + if not keep_periodic or not is_periodic: + should_remove = True + elif completed_at is None: + # For tasks without completed_at, use created_at or current time + created_at = getattr(task, 'created_at', None) + if created_at is not None and created_at < before_time: + if not keep_periodic or not is_periodic: + should_remove = True + + if should_remove: + if await self.delete_task(task.id): + removed_count += 1 + + logger.info(f"Cleaned up {removed_count} completed tasks") + return removed_count + except Exception as e: + logger.error(f"Failed to cleanup completed tasks: {e}") + return 0 + + async def get_statistics(self) -> dict: + """ + Get task statistics. 
+ + Returns: + Dictionary with task counts by status + """ + try: + total = await self.count() + pending = await self.count(TaskStatus.INIT) + running = await self.count(TaskStatus.RUNNING) + completed = await self.count(TaskStatus.SUCCESS) + failed = await self.count(TaskStatus.FAILED) + + return { + "total": total, + "pending": pending, + "running": running, + "completed": completed, + "failed": failed, + } + except Exception as e: + logger.error(f"Failed to get statistics: {e}") + return {} + + def mark_completed(self, task_id: str): + """Mark a task as completed in cache (for dependency tracking).""" + self._cache_completed.add(task_id) + + def get_completed_ids(self) -> Set[str]: + """Get set of completed task IDs from cache.""" + return self._cache_completed.copy() + + async def get_tasks_by_status(self, status: str, limit: Optional[int] = None) -> List[Task]: + # Create condition for status filtering + from aworld.core.storage.condition import Condition + + try: + condition = Condition().add("task_status", "==", status) + tasks = await self.storage.select_data(condition) + + # Sort by next_run_time (earlier first), with None values at the end + if tasks and hasattr(tasks[0], "next_run_time"): + tasks.sort(key=lambda t: (t.next_run_time is None, t.next_run_time or 0)) + if limit: + tasks = tasks[:limit] + return tasks + except Exception as e: + logger.error(f"Failed to get tasks by status {status}: {e}") + return [] + + async def get_ready_tasks(self, + current_time: Optional[float] = None, + limit: Optional[int] = None) -> List[Task]: + """Get tasks that are ready to execute. + + A task is ready if: + 1. Status is INIT (pending) + 2. Current time >= next_run_time (or scheduled_time) + 3. All dependencies are completed + 4. Within start_time and end_time constraints + 5. 
Not exceeded max_executions + + Args: + current_time: Current timestamp (default: now) + limit: Maximum number of tasks to return + + Returns: + List of ready tasks sorted by priority (high to low) and next_run_time + """ + current_time = current_time if current_time else time.time() + try: + pending_tasks = await self.get_tasks_by_status(TaskStatus.INIT) + ready_tasks = [] + for task in pending_tasks: + # Check if task has is_ready method (for ScheduledTask compatibility) + if hasattr(task, 'is_ready') and callable(getattr(task, 'is_ready')): + # Use task's is_ready method + if task.is_ready(self._cache_completed, current_time): + ready_tasks.append(task) + else: + # For basic Task without is_ready, check basic readiness + # Check dependencies if task has them + dependencies = getattr(task, 'dependencies', []) + deps_ready = all(dep_id in self._cache_completed for dep_id in dependencies) + + # Check time constraints if task has next_run_time + next_run_time = getattr(task, 'next_run_time', None) + time_ready = True + if next_run_time is not None: + time_ready = current_time >= next_run_time + + if deps_ready and time_ready: + ready_tasks.append(task) + + # Sort by priority (higher first) and next_run_time (earlier first) + if ready_tasks: + def sort_key(t): + priority = getattr(t, "priority", 0) + next_run_time = getattr(t, "next_run_time", None) + return (-priority, next_run_time if next_run_time is not None else float('inf')) + + ready_tasks.sort(key=sort_key) + + ready_tasks = ready_tasks[:limit] if limit else ready_tasks + return ready_tasks + except Exception as e: + logger.error(f"Failed to get ready tasks: {e}") + return [] + + async def get_periodic_tasks(self) -> List[Task]: + """Get all periodic tasks (tasks with cron expression). 
+ + Returns: + List of periodic tasks + """ + try: + # Get all tasks + all_tasks = await self.storage.get_data_items() + + # Filter periodic tasks + return [task for task in all_tasks if hasattr(task, 'is_periodic') and task.is_periodic] + except Exception as e: + logger.error(f"Failed to get periodic tasks: {e}") + return [] + + async def count_by_status(self, status: str) -> int: + """Count tasks by status. + + Args: + status: Task status to count + + Returns: + Number of tasks with the specified status + """ + try: + from aworld.core.storage.condition import Condition + + # Create condition for status filtering + condition = Condition().add("task_status", "==", status) + count = await self.storage.size(condition) + + return count + except Exception as e: + logger.error(f"Failed to count tasks by status {status}: {e}") + return 0 diff --git a/aworld/schedule/__init__.py b/aworld/schedule/__init__.py new file mode 100644 index 000000000..e8b7477a3 --- /dev/null +++ b/aworld/schedule/__init__.py @@ -0,0 +1,2 @@ +# coding: utf-8 +# Copyright (c) inclusionAI. diff --git a/aworld/schedule/scheduler.py b/aworld/schedule/scheduler.py new file mode 100644 index 000000000..4ad6dfc3c --- /dev/null +++ b/aworld/schedule/scheduler.py @@ -0,0 +1,501 @@ +# coding: utf-8 +# Copyright (c) inclusionAI. 
import asyncio
import time
from typing import Optional, List, Callable, Awaitable, Dict, Set, Any

from aworld.logs.util import logger
from aworld.runners.hook.hooks import Hook
from aworld.runners.hook.utils import run_hooks
from aworld.runners.runtime_engine import RuntimeEngine
from aworld.runners.task_manager import TaskManager
from aworld.runners.utils import runtime_engine
from aworld.schedule.strategy import ScheduleStrategy, create_strategy
from aworld.schedule.types import ScheduledTask, ResourceQuota, TaskStatistics
from aworld.core.common import TaskStatus
from aworld.core.task import Task, TaskResponse
from aworld.utils.run_util import exec_tasks


class TaskScheduler:
    """Task scheduler for managing and executing scheduled tasks.

    The scheduler supports:
    - Multiple scheduling strategies (FIFO, Priority, DAG, Auto, ...)
    - Resource quota management (maximum number of concurrent executions)
    - Concurrent task execution
    - Periodic task rescheduling
    - Task dependency resolution (through the task manager's completed cache)

    Examples:
        # Create scheduler with storage
        from aworld.core.storage.inmemory_store import InmemoryStorage
        manager = TaskManager(storage=InmemoryStorage())
        scheduler = TaskScheduler(task_manager=manager)

        # Add tasks
        task1 = ScheduledTask(id="task1", name="Task 1", priority=10)
        await scheduler.add_task(task1)

        # Set execution handler
        async def execute_handler(task):
            print(f"Executing {task.name}")
            await asyncio.sleep(1)
            return TaskResponse(success=True)

        scheduler.executor = execute_handler

        # Run one scheduling cycle
        await scheduler.run()

        # Or run scheduler in background
        scheduler.start()
        # ... do other work ...
        await scheduler.stop()
    """

    def __init__(
            self,
            task_manager: Optional[TaskManager] = None,
            strategy: Optional[ScheduleStrategy] = None,
            strategy_type: str = 'auto',
            executor: Optional[Callable[[Task], Awaitable[TaskResponse]]] = None,
            resource_quota: Optional[ResourceQuota] = None,
            max_concurrent: int = 10,
            poll_interval: float = 1.0,
            hooks: Optional[Dict[str, Hook]] = None
    ):
        """Initialize TaskScheduler.

        Args:
            task_manager: TaskManager instance. When omitted, a fresh ``TaskManager()``
                is created for this scheduler (it is no longer a def-time default that
                would be shared by every scheduler instance).
            strategy: Custom scheduling strategy (optional).
            strategy_type: Type of schedule strategy ('auto', 'fifo', 'priority', 'dag',
                or a registered custom name); used only when ``strategy`` is not given.
            executor: Async function that executes one task and returns a TaskResponse.
                When omitted, tasks run through ``execute_schedulable_tasks`` on the
                runtime engine.
            resource_quota: Resource quota for task execution.
            max_concurrent: Maximum concurrent tasks; used only when ``resource_quota``
                is not given.
            poll_interval: Seconds between scheduling cycles of the background loop.
            hooks: Accepted for interface compatibility; not used by this class.
        """
        self.task_manager = task_manager if task_manager is not None else TaskManager()
        self.strategy = strategy or create_strategy(strategy_type)
        self.resource_quota = resource_quota or ResourceQuota(max_concurrent=max_concurrent)
        self.poll_interval = poll_interval

        # Execution state
        self._executor: Optional[Callable[[Task], Awaitable[TaskResponse]]] = executor or execute_schedulable_tasks
        self._running = False
        self._paused = False
        self._scheduler_task: Optional[asyncio.Task] = None
        # task id -> asyncio task currently executing it (see execute_batch)
        self._active_tasks: Dict[str, asyncio.Task] = {}
        self._stop_event = asyncio.Event()

        # Statistics
        self.statistics = TaskStatistics()

    @property
    def executor(self):
        """Return the configured executor (``execute_schedulable_tasks`` by default)."""
        return self._executor

    @executor.setter
    def executor(self, executor: Callable[[Task], Awaitable[TaskResponse]]):
        """Set the task executor function.

        Args:
            executor: Async function that takes a Task and returns a TaskResponse.
                Setting ``None`` restores the runtime-engine default.

        Examples:
            async def my_executor(task):
                print(f"Executing {task.name}")
                # ... do work ...
                return TaskResponse(success=True, ...)

            scheduler.executor = my_executor
        """
        self._executor = executor

    async def schedule(self, tasks: Optional[List[ScheduledTask]] = None) -> List[List[ScheduledTask]]:
        """Schedule tasks into execution batches using the configured strategy.

        Args:
            tasks: Tasks to schedule (default: ready tasks from the task manager).

        Returns:
            List of task batches, where each batch can be executed in parallel.
        """
        if tasks is None:
            tasks = await self.task_manager.get_ready()

        if not tasks:
            return []

        # Apply scheduling strategy
        batches = await self.strategy.schedule(tasks)
        logger.debug(f"Scheduled {len(tasks)} tasks into {len(batches)} batches")
        return batches

    async def _invoke_executor(self, task: Task) -> Optional[TaskResponse]:
        """Run ``task`` through the custom executor, or the runtime engine by default."""
        if self._executor is None or self._executor is execute_schedulable_tasks:
            # The default executor has the signature (runtime_engine, tasks), so it
            # cannot be called like a custom single-task executor.
            responses = await execute_schedulable_tasks(await runtime_engine(), [task])
            return responses.get(task.id)
        return await self._executor(task)

    async def execute(self, task: Task) -> TaskResponse:
        """Execute a single task and record its status in the task manager.

        The task is marked RUNNING before execution and SUCCESS/FAILED afterwards.
        Successful tasks are added to the manager's completed cache, and successful
        periodic tasks are rescheduled.

        Args:
            task: Task to execute.

        Returns:
            The executor's TaskResponse. An empty ``TaskResponse()`` is returned when
            the executor raises or yields no response.
        """
        try:
            # Mark task as started
            await self.task_manager.update_status(
                task.id,
                TaskStatus.RUNNING,
                started_at=time.time()
            )

            logger.info(f"Executing task: {task.id} ({task.name})")

            # Execute task
            result = await self._invoke_executor(task)
            if result is None:
                logger.warning(f"No response returned for task {task.id}, treating it as failed")
                result = TaskResponse()

            # Mark task as completed
            await self.task_manager.update_status(
                task.id,
                TaskStatus.SUCCESS if result.success else TaskStatus.FAILED,
                completed_at=time.time()
            )

            # Update completed cache for dependency tracking
            if result.success:
                self.task_manager.mark_completed(task.id)

            logger.info(f"Task {'completed' if result.success else 'failed'}: {task.id}")

            # Handle periodic tasks
            if result.success and hasattr(task, 'is_periodic') and task.is_periodic:
                await self._reschedule_periodic_task(task)

            return result
        except Exception as e:
            logger.error(f"Error executing task {task.id}: {e}")

            # Mark task as failed
            await self.task_manager.update_status(
                task.id,
                TaskStatus.FAILED,
                completed_at=time.time()
            )
            return TaskResponse()

    async def _reschedule_periodic_task(self, task: ScheduledTask):
        """Reschedule a periodic task for next execution."""
        try:
            # Update next run time
            task.update_next_run_time()

            # Check if should continue
            if task.max_executions and task.execution_count >= task.max_executions:
                logger.info(f"Periodic task {task.id} reached max executions")
                return

            if task.end_time and time.time() > task.end_time:
                logger.info(f"Periodic task {task.id} reached end time")
                return

            # Reset for next execution
            task.task_status = TaskStatus.INIT
            task.started_at = None
            task.completed_at = None

            # Update in storage
            await self.task_manager.update_task(task)

            logger.info(f"Rescheduled periodic task {task.id} for {task.next_run_time}")

        except Exception as e:
            logger.error(f"Failed to reschedule periodic task {task.id}: {e}")

    def _forget_active(self, task_id: str, runner: asyncio.Task) -> None:
        """Drop ``task_id`` from the active map if ``runner`` is still its registered task."""
        if self._active_tasks.get(task_id) is runner:
            del self._active_tasks[task_id]

    async def execute_batch(self, batch: List[Task]) -> Dict[str, TaskResponse]:
        """Execute a batch of tasks concurrently.

        Each execution is registered in ``_active_tasks`` while it runs, so the
        concurrency quota and ``cancel_task`` act on real executions.

        Args:
            batch: Batch of tasks to execute.

        Returns:
            Dict mapping task id to its response. Tasks beyond the available slots are
            skipped; a task whose execution was cancelled or raised maps to an empty
            ``TaskResponse()``.
        """
        if not batch:
            return {}

        # Check resource availability
        current_concurrent = len(self._active_tasks)
        if self.resource_quota.max_concurrent > 0:
            available_slots = self.resource_quota.max_concurrent - current_concurrent
            if available_slots <= 0:
                logger.warning("Max concurrent tasks reached, skipping batch")
                return {}

            # Limit batch size to available slots
            batch = batch[:available_slots]

        # Execute tasks concurrently, tracking each one while it is in flight
        runners = []
        for task in batch:
            runner = asyncio.create_task(self.execute(task))
            self._active_tasks[task.id] = runner
            # A done-callback also fires when the runner is cancelled before it starts.
            runner.add_done_callback(
                lambda finished, task_id=task.id: self._forget_active(task_id, finished)
            )
            runners.append(runner)

        outcomes = await asyncio.gather(*runners, return_exceptions=True)

        # Key by the submitted task rather than by an attribute of the outcome:
        # an outcome may be an exception (e.g. CancelledError) instead of a response.
        responses: Dict[str, TaskResponse] = {}
        for task, outcome in zip(batch, outcomes):
            if isinstance(outcome, BaseException):
                logger.warning(f"Task {task.id} did not finish: {outcome!r}")
                outcome = TaskResponse()
            responses[task.id] = outcome
        return responses

    async def _run_once(self) -> int:
        """Run one scheduling cycle.

        Returns:
            int: Number of tasks executed.
        """
        # Get ready tasks
        ready_tasks = await self.task_manager.get_ready()

        if not ready_tasks:
            return 0

        # Schedule tasks into batches
        batches = await self.schedule(ready_tasks)

        executed_count = 0

        # Execute batches one after another; tasks inside a batch run concurrently
        for batch in batches:
            results = await self.execute_batch(batch)
            executed_count += len(results)

        return executed_count

    async def run(self, timeout: Optional[float] = None):
        """Run a single scheduling cycle in the foreground.

        Ready tasks are fetched, batched by the strategy and executed; the call
        returns once those batches have finished. Use ``start()`` for continuous
        background scheduling.

        Args:
            timeout: Accepted for interface compatibility.
                NOTE(review): the timeout is currently not enforced.

        Raises:
            Exception: Any error raised by the scheduling cycle is logged and re-raised.
        """
        start_time = time.time()
        logger.info("Starting scheduler")

        try:
            # Run one cycle
            executed = await self._run_once()

            # Check if no tasks to execute
            if executed == 0:
                pending = await self.task_manager.count(TaskStatus.INIT)
                if pending == 0:
                    logger.info("No more tasks to execute")
        except Exception as e:
            logger.error(f"Error in scheduler: {e}")
            raise
        finally:
            logger.info(f"Scheduler completed, time cost {time.time() - start_time}(s)")

    async def _scheduler_loop(self):
        """Main scheduler loop (for background execution)."""
        logger.info("Scheduler loop started")

        try:
            while self._running:
                # Sleep for poll_interval, waking up early if a stop is requested
                try:
                    await asyncio.wait_for(
                        self._stop_event.wait(),
                        timeout=self.poll_interval
                    )
                    # Stop event was set
                    break
                except asyncio.TimeoutError:
                    # Timeout - continue with scheduling cycle
                    pass

                # While paused the loop stays alive but schedules nothing
                if self._paused:
                    continue

                # Run one scheduling cycle
                await self._run_once()
        except Exception as e:
            logger.error(f"Error in scheduler loop: {e}")
        finally:
            self._running = False
            logger.info("Scheduler loop stopped")

    def start(self):
        """Start the scheduler in background.

        Must be called while an event loop is running, because the loop is started
        with ``asyncio.create_task``.

        Examples:
            scheduler.start()
            # ... scheduler runs in background ...
            await scheduler.stop()
        """
        if self._running:
            logger.warning("Scheduler already running")
            return

        self._running = True
        self._paused = False
        self._stop_event.clear()
        self._scheduler_task = asyncio.create_task(self._scheduler_loop())

        logger.info("Scheduler started in background")

    async def stop(self, wait: bool = True):
        """Stop the background scheduler.

        Args:
            wait: Wait for scheduler to complete current cycle.
        """
        if not self._running:
            logger.warning("Scheduler not running")
            return

        logger.info("Stopping scheduler...")

        self._running = False
        self._paused = False
        self._stop_event.set()

        if wait and self._scheduler_task:
            await self._scheduler_task

        # Cancel active tasks (iterate over a snapshot: entries are removed as tasks finish)
        for task_id, task in list(self._active_tasks.items()):
            if not task.done():
                task.cancel()
                logger.debug(f"Cancelled active task: {task_id}")

        self._active_tasks.clear()

        logger.info("Scheduler stopped!")

    @property
    def is_running(self) -> bool:
        """Check if scheduler is running."""
        return self._running

    async def get_statistics(self) -> dict:
        """Get scheduler statistics.

        Returns:
            Dictionary with the task manager's statistics plus scheduler state.
        """
        stats = await self.task_manager.get_statistics()

        stats.update({
            "is_running": self.is_running,
            "active_tasks": len(self._active_tasks),
            "max_concurrent": self.resource_quota.max_concurrent,
            "strategy": self.strategy.__class__.__name__,
            "poll_interval": self.poll_interval,
        })

        return stats

    async def pause(self):
        """Pause the scheduler (can be resumed).

        The background loop keeps running but skips scheduling cycles until
        ``resume()`` is called. Executions already in flight are not interrupted.
        """
        if not self._running:
            logger.warning("Scheduler not running")
            return

        # A dedicated flag is used: setting the stop event would end the loop.
        self._paused = True
        logger.info("Scheduler paused")

    async def resume(self):
        """Resume the paused scheduler."""
        if self._running and self._paused:
            self._paused = False
            logger.info("Scheduler resumed")
        else:
            logger.warning("Scheduler not paused or not running")

    async def add_task(self, task: ScheduledTask, overwrite: bool = True) -> bool:
        """Add a task to the scheduler.

        Args:
            task: Task to add.
            overwrite: Whether to overwrite existing task.

        Returns:
            bool: True if successful.
        """
        success = await self.task_manager.add_task(task, overwrite=overwrite)

        logger.info(f"Added task to scheduler {success}: {task.id} ({task.name})")
        return success

    async def add_tasks(self, tasks: List[ScheduledTask], overwrite: bool = True) -> int:
        """Add multiple tasks to the scheduler."""
        return await self.task_manager.add_batch(tasks, overwrite=overwrite)

    async def list_tasks(self, status: str = None, limit: Optional[int] = None, offset: int = 0) -> List[Task]:
        """List tasks from the task manager, optionally filtered by status."""
        tasks = await self.task_manager.list(status=status, limit=limit, offset=offset)

        logger.info(f"found {len(tasks)} tasks")
        return tasks

    async def delete_task(self, task_id: str) -> bool:
        """Delete a task from the task manager."""
        success = await self.task_manager.delete_task(task_id=task_id)

        logger.info(f"Deleted task {success}: {task_id}")
        return success

    async def cancel_task(self, task_id: str) -> bool:
        """Cancel a scheduled task.

        Args:
            task_id: ID of task to cancel.

        Returns:
            bool: True if successful.
        """
        # Cancel active execution
        active = self._active_tasks.pop(task_id, None)
        if active is not None:
            active.cancel()

        # Update task status
        return await self.task_manager.update_status(
            task_id,
            TaskStatus.CANCELLED,
            completed_at=time.time()
        )

    async def cleanup(self, before_time: Optional[float] = None):
        """Clean up old completed tasks.

        Args:
            before_time: Remove tasks completed before this time.

        Returns:
            Number of removed tasks, as reported by the task manager.
        """
        removed = await self.task_manager.cleanup_completed(before_time=before_time)
        logger.info(f"Cleaned up {removed} old tasks")
        return removed


async def execute_schedulable_tasks(runtime_engine: RuntimeEngine, tasks: List[Task]) -> Dict[str, Any]:
    """Execute scheduled tasks on the runtime engine.

    Each task is wrapped in an async function that sets its status and start time,
    runs it through ``exec_tasks`` and copies the response status and completion
    time back onto the task object.

    Args:
        runtime_engine: Engine used to run the wrapped task functions.
        tasks: Tasks to execute.

    Returns:
        Whatever ``runtime_engine.execute`` returns for the wrapped functions; an
        empty dict when ``tasks`` is empty.
    """
    if not tasks:
        return {}

    funcs = []
    # Wrap each task in an async executor function
    for scheduled_task in tasks:
        # Bind the task as a default argument to avoid the late-binding closure pitfall
        async def task_execute(st=scheduled_task):
            st.task_status = TaskStatus.RUNNING
            st.started_at = time.time()
            res = await exec_tasks(tasks=[st])
            task_response = res.get(st.id)
            if task_response:
                st.task_status = task_response.status
                st.completed_at = time.time()
            return task_response

        funcs.append(task_execute)

    results = await runtime_engine.execute(funcs)
    logger.info(f"{runtime_engine.name} execute {len(tasks)} tasks finished")
    return results

# Patch boundary, kept from the original diff:
# diff --git a/aworld/schedule/strategy.py b/aworld/schedule/strategy.py
# new file mode 100644
# index 000000000..9e4b663a4
# --- /dev/null
# +++ b/aworld/schedule/strategy.py
# @@ -0,0 +1,341 @@
# coding: utf-8
# Copyright (c) inclusionAI.
import heapq
from abc import abstractmethod, ABC
from collections import deque
from typing import List, Optional, Dict, Type

from aworld.logs.util import logger
from aworld.schedule.task_graph import TaskGraph
from aworld.schedule.types import ScheduledTask, ResourceQuota


def _drain_in_batches(queue: "TaskQueue", batch_size: int) -> List[List[ScheduledTask]]:
    """Pop every task from ``queue`` and group them into batches of ``batch_size``.

    A batch is closed as soon as it holds at least ``batch_size`` tasks, so a
    non-positive ``batch_size`` yields one task per batch.
    """
    batches: List[List[ScheduledTask]] = []
    current_batch: List[ScheduledTask] = []

    while not queue.is_empty():
        task = queue.pop()
        if task:
            current_batch.append(task)

        if len(current_batch) >= batch_size:
            batches.append(current_batch)
            current_batch = []

    if current_batch:
        batches.append(current_batch)

    return batches


class ScheduleStrategy(ABC):
    """Abstract base class for scheduling strategies."""

    @abstractmethod
    async def schedule(self, tasks: List[ScheduledTask], **kwargs) -> List[List[ScheduledTask]]:
        """Schedule tasks and return execution batches.

        Returns:
            List of task batches, where each batch can be executed in parallel.
        """


class FIFOStrategy(ScheduleStrategy):
    """First-In-First-Out scheduling strategy."""

    def __init__(self, batch_size: int = 10, **kwargs):
        self.batch_size = batch_size

    async def schedule(self, tasks: List[ScheduledTask], **kwargs) -> List[List[ScheduledTask]]:
        """Schedule tasks in FIFO order, ``batch_size`` tasks per batch."""
        queue = FIFOTaskQueue()
        for task in tasks:
            queue.push(task)
        return _drain_in_batches(queue, self.batch_size)


class PriorityStrategy(ScheduleStrategy):
    """Priority-based scheduling strategy."""

    def __init__(self, batch_size: int = 10, **kwargs):
        self.batch_size = batch_size

    async def schedule(self, tasks: List[ScheduledTask], **kwargs) -> List[List[ScheduledTask]]:
        """Schedule tasks by priority (higher first), ``batch_size`` tasks per batch."""
        queue = PriorityTaskQueue()
        for task in tasks:
            queue.push(task)
        return _drain_in_batches(queue, self.batch_size)


class DAGStrategy(ScheduleStrategy):
    """DAG-based scheduling strategy (topological sort)."""

    def __init__(self, **kwargs):
        # Accept (and ignore) keyword arguments so create_strategy(..., **kwargs)
        # works for every registered strategy, like FIFOStrategy/PriorityStrategy.
        pass

    async def schedule(self, tasks: List[ScheduledTask], **kwargs) -> List[List[ScheduledTask]]:
        """Schedule tasks based on DAG dependencies.

        Each returned batch is one topological level of the dependency graph. When
        the graph fails validation, the tasks are scheduled by priority instead.
        """
        dag = TaskGraph()
        dag.add_tasks(tasks)

        valid, error = TaskGraph.validate_dag(dag)
        if not valid:
            logger.error(f"DAG validation failed: {error}")
            return await PriorityStrategy().schedule(tasks, **kwargs)

        execution_levels = dag.get_execution_order()
        task_map = {task.id: task for task in tasks}

        batches = []
        for level in execution_levels:
            batch = [task_map[task_id] for task_id in level if task_id in task_map]
            if batch:
                batches.append(batch)

        return batches


class AutoStrategy(ScheduleStrategy):
    """Adaptive scheduling strategy that switches based on task conditions."""

    def __init__(self, batch_size: int = 10, **kwargs):
        self.dag_strategy = DAGStrategy()
        self.priority_strategy = PriorityStrategy(batch_size=batch_size)
        self.fifo_strategy = FIFOStrategy(batch_size=batch_size)

    async def schedule(self, tasks: List[ScheduledTask], **kwargs) -> List[List[ScheduledTask]]:
        """Adaptively choose scheduling strategy.

        DAG scheduling is used when any task declares dependencies, priority
        scheduling when the tasks carry more than one distinct priority, and FIFO
        otherwise.

        Args:
            tasks: Scheduled task list.
        """

        # check dependencies first
        if any(task.dependencies for task in tasks):
            logger.info("Using DAG strategy, tasks have dependencies")
            return await self.dag_strategy.schedule(tasks, **kwargs)

        # check task priority
        distinct_priorities = {task.priority for task in tasks}
        if len(distinct_priorities) > 1:
            logger.info("Using Priority strategy, tasks varied priorities")
            return await self.priority_strategy.schedule(tasks, **kwargs)

        # FIFO
        logger.info("Using FIFO strategy")
        return await self.fifo_strategy.schedule(tasks, **kwargs)


STRATEGY_MAP: Dict[str, Type[ScheduleStrategy]] = {
    "auto": AutoStrategy,
    "dag": DAGStrategy,
    "priority": PriorityStrategy,
    "fifo": FIFOStrategy,
}


def register_strategy(strategy_type: str, strategy: Type[ScheduleStrategy]) -> bool:
    """Register a new strategy, replacing any strategy already registered under the name.

    Args:
        strategy_type: Type of schedule strategy.
        strategy: Class of ScheduleStrategy

    Returns:
        Always True.
    """
    STRATEGY_MAP[strategy_type] = strategy
    return True


def create_strategy(strategy_type: Optional[str] = None, **kwargs) -> ScheduleStrategy:
    """Create a strategy instance; unknown or missing types fall back to AutoStrategy."""
    return STRATEGY_MAP.get(strategy_type, AutoStrategy)(**kwargs)


class TaskQueue(ABC):
    """Base class for task queues, simple API."""

    @abstractmethod
    def push(self, task: ScheduledTask):
        """Add a task to the queue."""
        pass

    @abstractmethod
    def pop(self) -> Optional[ScheduledTask]:
        """Remove and return the next task."""
        pass

    @abstractmethod
    def peek(self) -> Optional[ScheduledTask]:
        """Return the next task without removing it."""
        pass

    @abstractmethod
    def size(self) -> int:
        """Return the number of tasks in the queue."""
        pass

    @abstractmethod
    def is_empty(self) -> bool:
        """Check if queue is empty."""
        pass

    @abstractmethod
    def clear(self):
        """Clear all tasks from the queue."""
        pass


class FIFOTaskQueue(TaskQueue):
    """First in first out queue."""

    def __init__(self):
        self.queue = deque()

    def push(self, task: ScheduledTask):
        self.queue.append(task)

    def pop(self) -> Optional[ScheduledTask]:
        if self.queue:
            return self.queue.popleft()
        return None

    def peek(self) -> Optional[ScheduledTask]:
        if self.queue:
            return self.queue[0]
        return None

    def size(self) -> int:
        return len(self.queue)

    def is_empty(self) -> bool:
        return not self.queue

    def clear(self):
        self.queue.clear()


class PriorityTaskQueue(TaskQueue):
    """Priority-based queue using heap.

    Entries are ``(-priority, counter, task)`` tuples: the negated priority makes
    the min-heap pop the highest priority first, and the unique, monotonically
    increasing counter breaks ties in insertion order so tasks are never compared.
    """

    def __init__(self):
        self.heap: List[tuple] = []
        self.counter = 0

    def push(self, task: ScheduledTask):
        # (priority_value, counter, task)
        priority_value = task.priority
        heapq.heappush(self.heap, (-priority_value, self.counter, task))
        self.counter += 1

    def pop(self) -> Optional[ScheduledTask]:
        if self.heap:
            _, _, task = heapq.heappop(self.heap)
            return task
        return None

    def peek(self) -> Optional[ScheduledTask]:
        if self.heap:
            _, _, task = self.heap[0]
            return task
        return None

    def size(self) -> int:
        return len(self.heap)

    def is_empty(self) -> bool:
        return not self.heap

    def clear(self):
        self.heap.clear()
        self.counter = 0

    def update_priority(self, task_id: str, new_priority: int) -> None:
        """Update the priority of a queued task.

        The task is treated as re-inserted: among tasks of equal priority it now
        comes last. All other entries keep their relative order. Nothing happens
        when no queued task has the given id.

        Args:
            task_id: Task id.
            new_priority: The priority of the task.
        """
        for index, (_, _, task) in enumerate(self.heap):
            # Tasks are identified by ``id``, as everywhere else in the scheduler.
            if task.id == task_id:
                task.priority = new_priority
                self.heap[index] = (-new_priority, self.counter, task)
                self.counter += 1
                # Restore the heap invariant after the in-place replacement
                heapq.heapify(self.heap)
                return


class ResourceAwareTaskQueue(TaskQueue):
    """Resource-aware queue that considers resource constraints."""

    def __init__(self, resource_quota: ResourceQuota):
        self.priority_queue = PriorityTaskQueue()
        self.resource_quota = resource_quota
        self.pending_tasks: List[ScheduledTask] = []

    def push(self, task: ScheduledTask):
        """Add task considering resource requirements."""
        self.priority_queue.push(task)

    def pop(self) -> Optional[ScheduledTask]:
        """Pop the highest-priority task that fits the resource constraints.

        Tasks that cannot be allocated are put back into the queue. Returns None
        when the queue is empty or no queued task fits.
        """
        deferred = []
        selected = None

        while not self.priority_queue.is_empty():
            candidate = self.priority_queue.pop()
            if self._can_allocate(candidate):
                selected = candidate
                break
            deferred.append(candidate)

        # Return the skipped tasks to the queue
        for task in deferred:
            self.priority_queue.push(task)

        return selected

    def peek(self) -> Optional[ScheduledTask]:
        return self.priority_queue.peek()

    def size(self) -> int:
        return self.priority_queue.size()

    def is_empty(self) -> bool:
        return self.priority_queue.is_empty()

    def clear(self):
        self.priority_queue.clear()
        self.pending_tasks.clear()

    def _can_allocate(self, task: ScheduledTask) -> bool:
        """Check if task can be allocated given current resources."""
        # TODO: based global resource of cluster
        return True

    def update_resource_quota(self, quota: ResourceQuota):
        """Update resource quota."""
        self.resource_quota = quota

# Patch boundary, kept from the original diff:
# diff --git a/aworld/schedule/task_graph.py b/aworld/schedule/task_graph.py
# new file mode 100644
# index 000000000..f1fc2b8e1
# --- /dev/null
# +++ b/aworld/schedule/task_graph.py
# @@ -0,0 +1,260 @@
# coding: utf-8
# Copyright (c) inclusionAI.
from collections import deque
from typing import Dict, List, Set, Optional

from aworld.core.common import TaskStatus
from aworld.core.task import Task
from aworld.logs.util import logger
from aworld.schedule.types import SchedulableTask


class TaskGraph:
    """Directed Acyclic Graph for task scheduling."""

    def __init__(self):
        self.nodes: Dict[str, Task] = {}
        self.completed: Set[str] = set()
        # predecessor[task_id] = set of tasks that must complete before task_id
        self.predecessor: Dict[str, Set[str]] = {}
        # successor[task_id] = set of tasks that depend on task_id
        self.successor: Dict[str, Set[str]] = {}

    def add_tasks(self, tasks: List[Task]):
        """Add several tasks to the graph."""
        for task in tasks:
            self.add_task(task)

    def add_task(self, task: Task):
        """Add a task to the Graph.

        A task whose id is already present is ignored with a warning. Dependencies
        that are neither in the graph nor completed are still recorded as edges and
        reported with a warning.
        """
        if task.id in self.nodes:
            logger.warning(f"Task {task.id} already exists in DAG")
            return

        self.nodes[task.id] = task

        self.predecessor.setdefault(task.id, set())
        self.successor.setdefault(task.id, set())

        if hasattr(task, "dependencies"):
            for dep_id in task.dependencies:
                self.predecessor[task.id].add(dep_id)
                self.successor.setdefault(dep_id, set()).add(task.id)

                if dep_id not in self.nodes and dep_id not in self.completed:
                    logger.warning(f"Dependency {dep_id} not found for task {task.id}")

    def get_ready_tasks(self) -> List[Task]:
        """Get all INIT tasks that are ready to execute (no pending dependencies)."""
        ready_tasks = []
        for task_id, node in self.nodes.items():
            if node.task_status != TaskStatus.INIT:
                continue
            pending_predecessors = self.predecessor.get(task_id, set()) - self.completed
            if not pending_predecessors:
                ready_tasks.append(node)
        return ready_tasks

    def mark_completed(self, task_id: str):
        """Mark a task as completed and remove it from the graph.

        Dependents become ready through the ``completed`` set, which is subtracted
        from their predecessor sets.
        """
        if task_id not in self.nodes:
            logger.warning(f"Task {task_id} not found in DAG")
            return

        self.completed.add(task_id)

        self.predecessor.pop(task_id, None)
        self.successor.pop(task_id, None)

        del self.nodes[task_id]

    def mark_failed(self, task_id: str):
        """Mark a task as failed and fail every task that depends on it.

        The failure is propagated transitively: a task whose dependency can never
        complete can never become ready either. The failed task itself is removed
        from the graph; its dependents stay in the graph with status FAILED.
        """
        if task_id not in self.nodes:
            return

        # Breadth-first walk over direct and indirect dependents
        pending = deque(self.successor.get(task_id, ()))
        seen = {task_id}
        while pending:
            dependent_id = pending.popleft()
            if dependent_id in seen:
                continue
            seen.add(dependent_id)

            dependent = self.nodes.get(dependent_id)
            if dependent is None:
                continue
            dependent.task_status = TaskStatus.FAILED
            pending.extend(self.successor.get(dependent_id, ()))

        self.predecessor.pop(task_id, None)
        self.successor.pop(task_id, None)

        del self.nodes[task_id]

    def get_execution_order(self) -> List[List[str]]:
        """Get topological execution order (by levels) using Kahn's algorithm.

        Returns:
            List of levels; every task id in a level only depends on completed tasks
            or on tasks in earlier levels. Tasks that are part of a cycle, or that
            wait on a dependency which is not in the graph, appear in no level.
        """
        in_degree = {
            task_id: len(self.predecessor.get(task_id, set()) - self.completed)
            for task_id in self.nodes
        }

        queue = deque([task_id for task_id, degree in in_degree.items() if degree == 0])
        levels = []

        while queue:
            current_level = list(queue)
            levels.append(current_level)
            queue.clear()

            for task_id in current_level:
                for successor_id in self.successor.get(task_id, ()):
                    if successor_id in in_degree:
                        in_degree[successor_id] -= 1
                        if in_degree[successor_id] == 0:
                            queue.append(successor_id)

        return levels

    def has_cycle(self) -> bool:
        """Check if DAG has a cycle using an iterative depth-first search.

        An explicit stack is used instead of recursion so that long dependency
        chains cannot exceed the interpreter's recursion limit.
        """
        in_progress, finished = 1, 2
        state: Dict[str, int] = {}

        for root in self.nodes:
            if root in state:
                continue

            state[root] = in_progress
            stack = [(root, iter(self.successor.get(root, ())))]
            while stack:
                node_id, successors = stack[-1]
                for successor_id in successors:
                    successor_state = state.get(successor_id)
                    if successor_state == in_progress:
                        # Back edge to a node on the current DFS path
                        return True
                    if successor_state is None:
                        state[successor_id] = in_progress
                        stack.append((successor_id, iter(self.successor.get(successor_id, ()))))
                        break
                else:
                    # All successors explored
                    state[node_id] = finished
                    stack.pop()

        return False

    def get_critical_path(self) -> List[str]:
        """Get critical path (longest path) in DAG.

        Returns:
            Task ids along the longest dependency chain, from its first task to its
            last; an empty list for an empty graph. On a cyclic graph the result is
            not meaningful, but the search terminates.
        """
        longest_path = {}
        path_to = {}
        in_progress = set()

        def calculate_longest_path(task_id: str) -> int:
            if task_id in longest_path:
                return longest_path[task_id]

            if task_id not in self.nodes:
                longest_path[task_id] = 0
                return 0

            if task_id in in_progress:
                # Cycle guard: do not follow an edge back into the current path
                return 0

            if task_id not in self.successor or not self.successor[task_id]:
                longest_path[task_id] = 1
                return 1

            in_progress.add(task_id)
            max_length = 0
            next_task = None
            for successor_id in self.successor[task_id]:
                length = calculate_longest_path(successor_id)
                if length > max_length:
                    max_length = length
                    next_task = successor_id
            in_progress.discard(task_id)

            longest_path[task_id] = max_length + 1
            if next_task:
                path_to[task_id] = next_task
            return longest_path[task_id]

        for task_id in self.nodes:
            calculate_longest_path(task_id)

        if not longest_path:
            return []

        start = max(longest_path, key=longest_path.get)

        path = [start]
        current = start
        visited = {start}
        while current in path_to:
            current = path_to[current]
            if current in visited:
                break
            visited.add(current)
            path.append(current)

        return path

    def in_degree(self) -> Dict[str, int]:
        """Return the number of recorded predecessors for every task in the graph."""
        return {task_id: len(self.predecessor[task_id]) for task_id in self.nodes}

    def out_degree(self) -> Dict[str, int]:
        """Return the number of recorded successors for every task in the graph."""
        return {task_id: len(self.successor[task_id]) for task_id in self.nodes}

    def get_statistics(self) -> Dict:
        """Return summary counters for the graph.

        ``ready_tasks`` counts nodes without pending predecessors regardless of
        their status, while ``pending_tasks`` counts nodes in INIT status.
        """
        ready_count = sum(
            1 for task_id in self.nodes
            if not (self.predecessor.get(task_id, set()) - self.completed)
        )

        return {
            "total_tasks": len(self.nodes),
            "completed_tasks": len(self.completed),
            "pending_tasks": sum(1 for n in self.nodes.values() if n.task_status == TaskStatus.INIT),
            "ready_tasks": ready_count,
            "max_depth": len(self.get_execution_order()),
            "has_cycle": self.has_cycle(),
        }

    @staticmethod
    def validate_dag(dag: 'TaskGraph') -> tuple[bool, Optional[str]]:
        """Validate DAG structure.

        Returns:
            ``(True, None)`` when the graph is acyclic and every dependency is either
            a node of the graph or already completed; otherwise ``(False, reason)``.
        """
        # Check for cycles
        if dag.has_cycle():
            return False, "DAG contains cycles"

        # Check for orphaned dependencies
        all_task_ids = set(dag.nodes.keys()) | dag.completed
        for task_id in dag.nodes:
            for dep_id in dag.predecessor.get(task_id, ()):
                if dep_id not in all_task_ids:
                    return False, f"Task {task_id} has undefined dependency {dep_id}"

        return True, None

    @staticmethod
    def validate_task_dependencies(task: SchedulableTask, existing_tasks: Set[str]) -> tuple[bool, Optional[str]]:
        """Validate task dependencies before adding to DAG.

        A self-dependency is an error; a dependency missing from ``existing_tasks``
        only produces a warning.
        """
        for dep_id in task.dependencies:
            if dep_id == task.id:
                return False, "Task cannot depend on itself"

            if dep_id not in existing_tasks:
                logger.warning(f"Dependency {dep_id} not found in existing tasks")

        return True, None

# Patch boundary, kept from the original diff:
# diff --git a/aworld/schedule/types.py b/aworld/schedule/types.py
# new file mode 100644
# index 000000000..e262717bc
# --- /dev/null
# +++ b/aworld/schedule/types.py
# @@ -0,0 +1,263 @@
# coding: utf-8
# Copyright (c) inclusionAI.
+import time +from dataclasses import dataclass, field +from typing import List, Optional, Dict, Any, Set + +from aworld.core.common import TaskStatus +from aworld.core.task import Task + + +@dataclass +class SchedulableTask(Task): + """Instant schedulable task with dependencies and priority.""" + priority: int = 0 + retry_count: int = 0 + dependencies: List[str] = field(default_factory=list) + + # estimated resources + estimated_cpu: float = 0.0 + estimated_memory: float = 0.0 + estimated_time: float = 0.0 + + created_at: float = field(default_factory=time.time) + scheduled_at: Optional[float] = None + started_at: Optional[float] = None + completed_at: Optional[float] = None + metadata: Dict[str, Any] = field(default_factory=dict) + + def __lt__(self, other): + if not isinstance(other, SchedulableTask): + return NotImplemented + return self.priority < other.priority + + @property + def status(self) -> TaskStatus: + return self.task_status + + def is_ready(self, completed_tasks: Set[str]) -> bool: + """Check if task is ready to execute (all dependencies completed).""" + return all(dep_id in completed_tasks for dep_id in self.dependencies) + + def can_retry(self) -> bool: + """Check if task can be retried.""" + return self.retry_count < self.max_retry_count + + def duration(self) -> float: + """Get task execution duration.""" + if self.started_at and self.completed_at: + return self.completed_at - self.started_at + return 0.0 + + +@dataclass +class ScheduledTask(SchedulableTask): + """Unified scheduled task supporting cron-like scheduling, one-time execution, and delayed execution. + + Scheduling modes (priority order): + 1. Cron expression: Use cron_expression for periodic tasks (e.g., "0 0 * * *" for daily at midnight) + 2. One-time scheduled: Use scheduled_time for one-time execution at specific time + 3. Delayed execution: Use delay for execution after specified seconds + 4. 
Instant: If none of above is set, execute immediately + + Examples: + # Cron periodic task (every 5 minutes) + ScheduledTask(..., cron_expression="*/5 * * * *") + + # One-time scheduled task (at specific timestamp) + ScheduledTask(..., scheduled_time=1678886400.0) + + # Delayed task (execute after 60 seconds) + ScheduledTask(..., delay=60.0) + + # Instant task (execute immediately) + ScheduledTask(...) + """ + + # Cron expression for periodic tasks (e.g., "*/5 * * * *" for every 5 minutes) + # Format: minute hour day month weekday + cron_expression: Optional[str] = None + + # Scheduled time for one-time execution (Unix timestamp) + scheduled_time: Optional[float] = None + + # Delay in seconds for delayed execution + delay: Optional[float] = None + + # Time range constraints + start_time: Optional[float] = None # Start time (Unix timestamp) + end_time: Optional[float] = None # End time (Unix timestamp) + + # Execution limits + max_executions: Optional[int] = None # Max number of executions (for cron tasks) + execution_count: int = 0 # Current execution count + + # Next execution time (calculated from cron/scheduled_time/delay) + next_run_time: Optional[float] = None + + def __post_init__(self): + """Initialize next_run_time based on scheduling mode.""" + if self.next_run_time is None: + if self.cron_expression: + # Calculate next run time from cron expression + self.next_run_time = self._calculate_next_cron_time() + elif self.scheduled_time is not None: + # One-time scheduled task + self.next_run_time = self.scheduled_time + elif self.delay is not None: + # Delayed task + self.next_run_time = self.created_at + self.delay + else: + # Instant task (no scheduling) + self.next_run_time = self.created_at + + if self.start_time is None: + self.start_time = self.created_at + + def __lt__(self, other): + if not isinstance(other, ScheduledTask): + return NotImplemented + # Compare by next_run_time, then by priority + if self.next_run_time != other.next_run_time: + return 
(self.next_run_time or 0) < (other.next_run_time or 0) + return self.priority < other.priority + + @property + def is_periodic(self) -> bool: + """Check if this is a periodic task.""" + return self.cron_expression is not None + + @property + def is_one_time(self) -> bool: + """Check if this is a one-time task.""" + return self.cron_expression is None + + def is_ready(self, completed_tasks: Set[str], current_time: Optional[float] = None) -> bool: + """Check if task is ready to execute.""" + if current_time is None: + current_time = time.time() + + # Check time range + if self.start_time is not None and current_time < self.start_time: + return False + if self.end_time is not None and current_time > self.end_time: + return False + + # Check max executions + if self.max_executions is not None and self.execution_count >= self.max_executions: + return False + + # Check if time has come + time_ready = self.next_run_time is not None and current_time >= self.next_run_time + + # Check dependencies + deps_ready = all(dep_id in completed_tasks for dep_id in self.dependencies) + + return time_ready and deps_ready + + def update_next_run_time(self): + """Update next run time after execution.""" + self.execution_count += 1 + + if self.cron_expression: + # Periodic task: calculate next run time from cron + self.next_run_time = self._calculate_next_cron_time() + else: + # One-time task: no next run + self.next_run_time = None + + def _calculate_next_cron_time(self, from_time: Optional[float] = None) -> Optional[float]: + """Calculate next execution time from cron expression. 
+ + Simplified cron format: minute hour day month weekday + Supports: numbers, *, */n (every n), ranges (1-5), lists (1,3,5) + """ + if not self.cron_expression: + return None + + if from_time is None: + from_time = time.time() + + # Try to use croniter library if available + try: + from croniter import croniter + cron = croniter(self.cron_expression, from_time) + return cron.get_next() + except ImportError: + # Fallback: simple interval-based calculation for */n patterns + parts = self.cron_expression.split() + if len(parts) >= 1 and parts[0].startswith('*/'): + try: + interval_minutes = int(parts[0][2:]) + return from_time + (interval_minutes * 60) + except ValueError: + pass + # Default: execute once after 1 hour + return from_time + 3600 + + +@dataclass +class TaskStatistics: + """Global scheduling state.""" + total_tasks: int = 0 + pending_tasks: int = 0 + running_tasks: int = 0 + completed_tasks: int = 0 + failed_tasks: int = 0 + abnormal_tasks: int = 0 + + used_cpu: float = 0.0 + used_memory: float = 0.0 + used_cost: float = 0.0 + used_time: float = 0.0 + current_concurrent: int = 0 + + avg_wait_time: float = 0.0 + avg_execution_time: float = 0.0 + throughput: float = 0.0 + + def update_from_task(self, task: Task): + """Update state from task.""" + if task.task_status == TaskStatus.INIT: + self.pending_tasks += 1 + elif task.task_status == TaskStatus.RUNNING: + self.running_tasks += 1 + self.current_concurrent += 1 + elif task.task_status == TaskStatus.SUCCESS: + self.completed_tasks += 1 + self.current_concurrent = max(0, self.current_concurrent - 1) + elif task.task_status == TaskStatus.FAILED: + self.failed_tasks += 1 + self.current_concurrent = max(0, self.current_concurrent - 1) + elif task.task_status in (TaskStatus.CANCELLED, TaskStatus.INTERRUPTED, TaskStatus.TIMEOUT): + self.abnormal_tasks += 1 + self.current_concurrent = max(0, self.current_concurrent - 1) + + self.total_tasks += 1 + + +@dataclass +class ResourceQuota: + """Resource quota of the 
machine for task execution.""" + max_cpu: float = 0.0 + max_memory: float = 0.0 + max_time: float = 0.0 + max_concurrent: int = 0 + + def is_available(self, + used_cpu: float = 0, + used_memory: float = 0, + used_time: float = 0, + current_concurrent: int = 0) -> bool: + """Check if resources are available.""" + checks = [] + if self.max_cpu > 0: + checks.append(used_cpu < self.max_cpu) + if self.max_memory > 0: + checks.append(used_memory < self.max_memory) + if self.max_time > 0: + checks.append(used_time < self.max_time) + if self.max_concurrent > 0: + checks.append(current_concurrent < self.max_concurrent) + + return all(checks) if checks else True diff --git a/tests/memory/agent/self_evolving_agent.py b/tests/memory/agent/self_evolving_agent.py index fadc69541..ce2d9837f 100644 --- a/tests/memory/agent/self_evolving_agent.py +++ b/tests/memory/agent/self_evolving_agent.py @@ -16,7 +16,6 @@ from aworld.memory.main import MemoryFactory from aworld.memory.models import LongTermMemoryTriggerParams, MemoryAIMessage, MessageMetadata, UserProfile, \ MemoryHumanMessage -from aworld.memory.utils import build_history_context from aworld.output import AworldUI from aworld.output.utils import load_workspace from aworld.prompt import Prompt @@ -25,6 +24,19 @@ from tests.memory.prompts import SELF_EVOLVING_USER_INPUT_REWRITE_PROMPT, RESEARCH_PROMPT +def build_history_context(history_messages: list[MemoryItem]) -> str: + """ + Build history context from history messages. + """ + history_context = "" + for message in history_messages: + if message.role == "user": + history_context += f"User: {message.content}\n" + else: + history_context += f"Agent: {message.content}\n" + return history_context + + class SuperAgent: """ Super agent