From 8c5a76d5a4a0a1212459075415115cae049d5693 Mon Sep 17 00:00:00 2001 From: ganrunsheng Date: Thu, 5 Mar 2026 14:06:15 +0800 Subject: [PATCH 01/10] schedule related structure; schedule strategy --- aworld/schedule/__init__.py | 2 + aworld/schedule/strategy.py | 341 ++++++++++++++++++++++++++++++++++ aworld/schedule/task_graph.py | 260 ++++++++++++++++++++++++++ aworld/schedule/types.py | 263 ++++++++++++++++++++++++++ 4 files changed, 866 insertions(+) create mode 100644 aworld/schedule/__init__.py create mode 100644 aworld/schedule/strategy.py create mode 100644 aworld/schedule/task_graph.py create mode 100644 aworld/schedule/types.py diff --git a/aworld/schedule/__init__.py b/aworld/schedule/__init__.py new file mode 100644 index 000000000..e8b7477a3 --- /dev/null +++ b/aworld/schedule/__init__.py @@ -0,0 +1,2 @@ +# coding: utf-8 +# Copyright (c) inclusionAI. diff --git a/aworld/schedule/strategy.py b/aworld/schedule/strategy.py new file mode 100644 index 000000000..98a819c26 --- /dev/null +++ b/aworld/schedule/strategy.py @@ -0,0 +1,341 @@ +# coding: utf-8 +# Copyright (c) inclusionAI. +import heapq +from abc import abstractmethod, ABC +from collections import deque +from typing import List, Optional, Dict, Type + +from aworld.logs.util import logger +from aworld.schedule.task_graph import TaskGraph +from aworld.schedule.types import ScheduledTask, ResourceQuota + + +class ScheduleStrategy(ABC): + """Abstract base class for scheduling strategies.""" + + @abstractmethod + async def schedule(self, tasks: List[ScheduledTask], **kwargs) -> List[List[ScheduledTask]]: + """ + Schedule tasks and return execution batches. + + Returns: + List of task batches, where each batch can be executed in parallel. + """ + + +class FIFOStrategy(ScheduleStrategy): + """First-In-First-Out scheduling default strategy.""" + + def __init__(self, batch_size: int = 10, **kwargs): + self.batch_size = batch_size + + async def schedule(self, tasks: List[ScheduledTask], **kwargs) -> List[List[ScheduledTask]]: + """Schedule tasks in FIFO order.""" + queue = FIFOTaskQueue() + for task in tasks: + queue.push(task) + + batches = [] + current_batch = [] + + while not queue.is_empty(): + task = queue.pop() + if task: + current_batch.append(task) + + if len(current_batch) >= self.batch_size: + batches.append(current_batch) + current_batch = [] + + if current_batch: + batches.append(current_batch) + + return batches + + +class PriorityStrategy(ScheduleStrategy): + """Priority-based scheduling strategy.""" + + def __init__(self, batch_size: int = 10, **kwargs): + self.batch_size = batch_size + + async def schedule(self, tasks: List[ScheduledTask], **kwargs) -> List[List[ScheduledTask]]: + """Schedule tasks by priority.""" + queue = PriorityTaskQueue() + for task in tasks: + queue.push(task) + + batches = [] + current_batch = [] + + while not queue.is_empty(): + task = queue.pop() + if task: + current_batch.append(task) + + if len(current_batch) >= self.batch_size: + batches.append(current_batch) + current_batch = [] + + if current_batch: + batches.append(current_batch) + + return batches + + +class DAGStrategy(ScheduleStrategy): + """DAG-based scheduling strategy (topological sort).""" + + async def schedule(self, tasks: List[ScheduledTask], **kwargs) -> List[List[ScheduledTask]]: + """Schedule tasks based on DAG dependencies.""" + dag = TaskGraph() + dag.add_tasks(tasks) + + valid, error = TaskGraph.validate_dag(dag) + if not valid: + logger.error(f"DAG validation failed: {error}") + return await PriorityStrategy().schedule(tasks, **kwargs) + + execution_levels = dag.get_execution_order() + task_map = {task.id: task for task in tasks} + + batches = [] + for level in execution_levels: + batch = [task_map[task_id] for task_id in level if task_id in task_map] + if batch: + batches.append(batch) + + return batches + + +class AutoStrategy(ScheduleStrategy): + """Adaptive scheduling strategy that switches based on task conditions.""" + + def __init__(self, batch_size: int = 10): + self.dag_strategy = DAGStrategy() + self.priority_strategy = PriorityStrategy(batch_size=batch_size) + self.fifo_strategy = FIFOStrategy(batch_size=batch_size) + + async def schedule(self, tasks: List[ScheduledTask], **kwargs) -> List[List[ScheduledTask]]: + """Adaptively choose scheduling strategy. + + Args: + tasks: Scheduled task list. + """ + + # check dependencies first + if any(task.dependencies for task in tasks): + logger.info("Using DAG strategy, tasks have dependencies") + return await self.dag_strategy.schedule(tasks, **kwargs) + + # check task priority + priority_counts = {} + for task in tasks: + priority_counts[task.priority] = priority_counts.get(task.priority, 0) + 1 + + if len(priority_counts) > 1: + logger.info("Using Priority strategy, tasks varied priorities") + return await self.priority_strategy.schedule(tasks, **kwargs) + + # FIFO + logger.info("Using FIFO strategy") + return await self.fifo_strategy.schedule(tasks, **kwargs) + + +STRATEGY_MAP: Dict[str, Type[ScheduleStrategy]] = { + "auto": AutoStrategy, + "dag": DAGStrategy, + "priority": PriorityStrategy, + "fifo": FIFOStrategy, +} + + +def register_strategy(strategy_type: str, strategy: Type[ScheduleStrategy]) -> bool: + """Register a new strategy. + + Args: + strategy_type: Type of schedule strategy. + strategy: Class of ScheduleStrategy + """ + STRATEGY_MAP[strategy_type] = strategy + return True + + +def create_strategy(strategy_type: Optional[str] = None, **kwargs) -> ScheduleStrategy: + """Create strategy instance.""" + return STRATEGY_MAP.get(strategy_type, AutoStrategy())(**kwargs) + + +class TaskQueue(ABC): + """Base class for task queues, simple API.""" + + @abstractmethod + def push(self, task: ScheduledTask): + """Add a task to the queue.""" + pass + + @abstractmethod + def pop(self) -> Optional[ScheduledTask]: + """Remove and return the next task.""" + pass + + @abstractmethod + def peek(self) -> Optional[ScheduledTask]: + """Return the next task without removing it.""" + pass + + @abstractmethod + def size(self) -> int: + """Return the number of tasks in the queue.""" + pass + + @abstractmethod + def is_empty(self) -> bool: + """Check if queue is empty.""" + pass + + @abstractmethod + def clear(self): + """Clear all tasks from the queue.""" + pass + + +class FIFOTaskQueue(TaskQueue): + """First in first out queue.""" + + def __init__(self): + self.queue = deque() + + def push(self, task: ScheduledTask): + self.queue.append(task) + + def pop(self) -> Optional[ScheduledTask]: + if self.queue: + return self.queue.popleft() + return None + + def peek(self) -> Optional[ScheduledTask]: + if self.queue: + return self.queue[0] + return None + + def size(self) -> int: + return len(self.queue) + + def is_empty(self) -> bool: + return len(self.queue) == 0 + + def clear(self): + self.queue.clear() + + +class PriorityTaskQueue(TaskQueue): + """Priority-based queue using heap.""" + + def __init__(self): + self.heap: List[tuple] = [] + self.counter = 0 + + def push(self, task: ScheduledTask): + # (priority_value, counter, task) + priority_value = task.priority + heapq.heappush(self.heap, (priority_value, self.counter, task)) + self.counter += 1 + + def pop(self) -> Optional[ScheduledTask]: + if self.heap: + _, _, task = heapq.heappop(self.heap) + return task + return None + + def peek(self) -> Optional[ScheduledTask]: + if self.heap: + _, _, task = self.heap[0] + return task + return None + + def size(self) -> int: + return len(self.heap) + + def is_empty(self) -> bool: + return len(self.heap) == 0 + + def clear(self): + self.heap.clear() + self.counter = 0 + + def update_priority(self, task_id: str, new_priority: int) -> None: + """Update the task priority (rebuild heap). + + Args: + task_id: Task id. + new_priority: The priority of the task. + """ + tasks = [task for _, _, task in self.heap if task.task_id != task_id] + task_to_update = None + for _, _, task in self.heap: + if task.task_id == task_id: + task_to_update = task + break + + self.clear() + for task in tasks: + self.push(task) + + if task_to_update: + task_to_update.priority = new_priority + self.push(task_to_update) + + +class ResourceAwareTaskQueue(TaskQueue): + """Resource-aware queue that considers resource constraints.""" + + def __init__(self, resource_quota: ResourceQuota): + self.priority_queue = PriorityTaskQueue() + self.resource_quota = resource_quota + self.pending_tasks: List[ScheduledTask] = [] + + def push(self, task: ScheduledTask): + """Add task considering resource requirements.""" + self.priority_queue.push(task) + + def pop(self) -> Optional[ScheduledTask]: + """Pop task that fits resource constraints.""" + temp_tasks = [] + + while not self.priority_queue.is_empty(): + task = self.priority_queue.pop() + + if self._can_allocate(task): + for t in temp_tasks: + self.priority_queue.push(t) + return task + else: + temp_tasks.append(task) + + for t in temp_tasks: + self.priority_queue.push(t) + + return None + + def peek(self) -> Optional[ScheduledTask]: + return self.priority_queue.peek() + + def size(self) -> int: + return self.priority_queue.size() + + def is_empty(self) -> bool: + return self.priority_queue.is_empty() + + def clear(self): + self.priority_queue.clear() + self.pending_tasks.clear() + + def _can_allocate(self, task: ScheduledTask) -> bool: + """Check if task can be allocated given current resources.""" + # TODO: based global resource of cluster + return True + + def update_resource_quota(self, quota: ResourceQuota): + """Update resource quota.""" + self.resource_quota = quota diff --git a/aworld/schedule/task_graph.py b/aworld/schedule/task_graph.py new file mode 100644 index 000000000..70db235f8 --- /dev/null +++ b/aworld/schedule/task_graph.py @@ -0,0 +1,260 @@ +# coding: utf-8 +# Copyright (c) inclusionAI. +from collections import deque +from typing import Dict, List, Set, Optional + +from aworld.core.common import TaskStatusValue +from aworld.core.task import Task +from aworld.logs.util import logger +from aworld.schedule.types import InstantTask + + +class TaskGraph: + """Directed Acyclic Graph for task scheduling.""" + + def __init__(self): + self.nodes: Dict[str, Task] = {} + self.completed: Set[str] = set() + # predecessor[task_id] = set of tasks that must complete before task_id + self.predecessor: Dict[str, Set[str]] = {} + # successor[task_id] = set of tasks that depend on task_id + self.successor: Dict[str, Set[str]] = {} + + def add_tasks(self, tasks: List[Task]): + for task in tasks: + self.add_task(task) + + def add_task(self, task: Task): + """Add a task to the Graph.""" + if task.id in self.nodes: + logger.warning(f"Task {task.id} already exists in DAG") + return + + self.nodes[task.id] = task + + if task.id not in self.predecessor: + self.predecessor[task.id] = set() + if task.id not in self.successor: + self.successor[task.id] = set() + + if hasattr(task, "dependencies"): + for dep_id in task.dependencies: + self.predecessor[task.id].add(dep_id) + + if dep_id not in self.successor: + self.successor[dep_id] = set() + self.successor[dep_id].add(task.id) + + if dep_id not in self.nodes and dep_id not in self.completed: + logger.warning(f"Dependency {dep_id} not found for task {task.id}") + + def get_ready_tasks(self) -> List[Task]: + """Get all tasks that are ready to execute (no pending dependencies).""" + ready_tasks = [] + for task_id, node in self.nodes.items(): + if node.task_status == TaskStatusValue.INIT: + if task_id in self.predecessor: + pending_predecessors = self.predecessor[task_id] - self.completed + if not pending_predecessors: + ready_tasks.append(node) + else: + ready_tasks.append(node) + return ready_tasks + + def mark_completed(self, task_id: str): + """Mark a task as completed and update dependent tasks.""" + if task_id not in self.nodes: + logger.warning(f"Task {task_id} not found in DAG") + return + + self.completed.add(task_id) + + if task_id in self.predecessor: + del self.predecessor[task_id] + if task_id in self.successor: + del self.successor[task_id] + + del self.nodes[task_id] + + def mark_failed(self, task_id: str): + """Mark a task as failed and handle dependents.""" + if task_id not in self.nodes: + return + + if task_id in self.successor: + for dependent_id in self.successor[task_id]: + if dependent_id in self.nodes: + self.nodes[dependent_id].task_status = TaskStatusValue.FAILED + + if task_id in self.predecessor: + del self.predecessor[task_id] + if task_id in self.successor: + del self.successor[task_id] + + del self.nodes[task_id] + + def get_execution_order(self) -> List[List[str]]: + """Get topological execution order (by levels) using Kahn's algorithm.""" + + in_degree = {} + for task_id in self.nodes: + if task_id in self.predecessor: + pending = self.predecessor[task_id] - self.completed + in_degree[task_id] = len(pending) + else: + in_degree[task_id] = 0 + + queue = deque([task_id for task_id, degree in in_degree.items() if degree == 0]) + levels = [] + + while queue: + current_level = list(queue) + levels.append(current_level) + queue.clear() + + for task_id in current_level: + if task_id not in self.successor: + continue + + for successor_id in self.successor[task_id]: + if successor_id in in_degree: + in_degree[successor_id] -= 1 + if in_degree[successor_id] == 0: + queue.append(successor_id) + + return levels + + def has_cycle(self) -> bool: + """Check if DAG has a cycle using DFS.""" + visited = set() + rec_stack = set() + + def dfs(node_id: str) -> bool: + visited.add(node_id) + rec_stack.add(node_id) + + if node_id in self.successor: + for successor_id in self.successor[node_id]: + if successor_id not in visited: + if dfs(successor_id): + return True + elif successor_id in rec_stack: + return True + + rec_stack.remove(node_id) + return False + + for task_id in self.nodes: + if task_id not in visited: + if dfs(task_id): + return True + + return False + + def get_critical_path(self) -> List[str]: + """Get critical path (longest path) in DAG.""" + longest_path = {} + path_to = {} + + def calculate_longest_path(task_id: str) -> int: + if task_id in longest_path: + return longest_path[task_id] + + if task_id not in self.nodes: + longest_path[task_id] = 0 + return 0 + + if task_id not in self.successor or not self.successor[task_id]: + longest_path[task_id] = 1 + return 1 + + max_length = 0 + next_task = None + for successor_id in self.successor[task_id]: + length = calculate_longest_path(successor_id) + if length > max_length: + max_length = length + next_task = successor_id + + longest_path[task_id] = max_length + 1 + if next_task: + path_to[task_id] = next_task + return longest_path[task_id] + + for task_id in self.nodes: + calculate_longest_path(task_id) + + if not longest_path: + return [] + + start = max(longest_path, key=longest_path.get) + + path = [start] + current = start + while current in path_to: + current = path_to[current] + path.append(current) + + return path + + def in_degree(self) -> Dict[str, int]: + in_degree = {} + for k, _ in self.nodes.items(): + tasks = self.predecessor[k] + in_degree[k] = len(tasks) + return in_degree + + def out_degree(self) -> Dict[str, int]: + out_degree = {} + for k, _ in self.nodes.items(): + tasks = self.successor[k] + out_degree[k] = len(tasks) + return out_degree + + def get_statistics(self) -> Dict: + ready_count = 0 + for task_id in self.nodes: + if task_id in self.predecessor: + pending = self.predecessor[task_id] - self.completed + if not pending: + ready_count += 1 + else: + ready_count += 1 + + return { + "total_tasks": len(self.nodes), + "completed_tasks": len(self.completed), + "pending_tasks": sum(1 for n in self.nodes.values() if n.task_status == TaskStatusValue.INIT), + "ready_tasks": ready_count, + "max_depth": len(self.get_execution_order()), + "has_cycle": self.has_cycle(), + } + + @staticmethod + def validate_dag(dag: 'TaskGraph') -> tuple[bool, Optional[str]]: + """Validate DAG structure.""" + # Check for cycles + if dag.has_cycle(): + return False, "DAG contains cycles" + + # Check for orphaned dependencies + all_task_ids = set(dag.nodes.keys()) | dag.completed + for task_id in dag.nodes: + if task_id in dag.predecessor: + for dep_id in dag.predecessor[task_id]: + if dep_id not in all_task_ids: + return False, f"Task {task_id} has undefined dependency {dep_id}" + + return True, None + + @staticmethod + def validate_task_dependencies(task: InstantTask, existing_tasks: Set[str]) -> tuple[bool, Optional[str]]: + """Validate task dependencies before adding to DAG.""" + for dep_id in task.dependencies: + if dep_id == task.id: + return False, "Task cannot depend on itself" + + if dep_id not in existing_tasks: + logger.warning(f"Dependency {dep_id} not found in existing tasks") + + return True, None diff --git a/aworld/schedule/types.py b/aworld/schedule/types.py new file mode 100644 index 000000000..77d5cceb5 --- /dev/null +++ b/aworld/schedule/types.py @@ -0,0 +1,263 @@ +# coding: utf-8 +# Copyright (c) inclusionAI. +import time +from dataclasses import dataclass, field +from typing import List, Optional, Dict, Any, Set + +from aworld.core.common import TaskStatus, TaskStatusValue +from aworld.core.task import Task + + +@dataclass +class InstantTask(Task): + """Instant task with dependencies and metadata.""" + priority: int = 0 + retry_count: int = 0 + dependencies: List[str] = field(default_factory=list) + + # estimated resources + estimated_cpu: float = 0.0 + estimated_memory: float = 0.0 + estimated_time: float = 0.0 + + created_at: float = field(default_factory=time.time) + scheduled_at: Optional[float] = None + started_at: Optional[float] = None + completed_at: Optional[float] = None + metadata: Dict[str, Any] = field(default_factory=dict) + + def __lt__(self, other): + if not isinstance(other, InstantTask): + return NotImplemented + return self.priority < other.priority + + @property + def status(self) -> TaskStatus: + return self.task_status + + def is_ready(self, completed_tasks: Set[str]) -> bool: + """Check if task is ready to execute (all dependencies completed).""" + return all(dep_id in completed_tasks for dep_id in self.dependencies) + + def can_retry(self) -> bool: + """Check if task can be retried.""" + return self.retry_count < self.max_retry_count + + def duration(self) -> float: + """Get task execution duration.""" + if self.started_at and self.completed_at: + return self.completed_at - self.started_at + return 0.0 + + +@dataclass +class ScheduledTask(InstantTask): + """Unified scheduled task supporting cron-like scheduling, one-time execution, and delayed execution. + + Scheduling modes (priority order): + 1. Cron expression: Use cron_expression for periodic tasks (e.g., "0 0 * * *" for daily at midnight) + 2. One-time scheduled: Use scheduled_time for one-time execution at specific time + 3. Delayed execution: Use delay for execution after specified seconds + 4. Instant: If none of above is set, execute immediately + + Examples: + # Cron periodic task (every 5 minutes) + ScheduledTask(..., cron_expression="*/5 * * * *") + + # One-time scheduled task (at specific timestamp) + ScheduledTask(..., scheduled_time=1678886400.0) + + # Delayed task (execute after 60 seconds) + ScheduledTask(..., delay=60.0) + + # Instant task (execute immediately) + ScheduledTask(...) + """ + + # Cron expression for periodic tasks (e.g., "*/5 * * * *" for every 5 minutes) + # Format: minute hour day month weekday + cron_expression: Optional[str] = None + + # Scheduled time for one-time execution (Unix timestamp) + scheduled_time: Optional[float] = None + + # Delay in seconds for delayed execution + delay: Optional[float] = None + + # Time range constraints + start_time: Optional[float] = None # Start time (Unix timestamp) + end_time: Optional[float] = None # End time (Unix timestamp) + + # Execution limits + max_executions: Optional[int] = None # Max number of executions (for cron tasks) + execution_count: int = 0 # Current execution count + + # Next execution time (calculated from cron/scheduled_time/delay) + next_run_time: Optional[float] = None + + def __post_init__(self): + """Initialize next_run_time based on scheduling mode.""" + if self.next_run_time is None: + if self.cron_expression: + # Calculate next run time from cron expression + self.next_run_time = self._calculate_next_cron_time() + elif self.scheduled_time is not None: + # One-time scheduled task + self.next_run_time = self.scheduled_time + elif self.delay is not None: + # Delayed task + self.next_run_time = self.created_at + self.delay + else: + # Instant task (no scheduling) + self.next_run_time = self.created_at + + if self.start_time is None: + self.start_time = self.created_at + + def __lt__(self, other): + if not isinstance(other, ScheduledTask): + return NotImplemented + # Compare by next_run_time, then by priority + if self.next_run_time != other.next_run_time: + return (self.next_run_time or 0) < (other.next_run_time or 0) + return self.priority < other.priority + + @property + def is_periodic(self) -> bool: + """Check if this is a periodic task.""" + return self.cron_expression is not None + + @property + def is_one_time(self) -> bool: + """Check if this is a one-time task.""" + return self.cron_expression is None + + def is_ready(self, completed_tasks: Set[str], current_time: Optional[float] = None) -> bool: + """Check if task is ready to execute.""" + if current_time is None: + current_time = time.time() + + # Check time range + if self.start_time is not None and current_time < self.start_time: + return False + if self.end_time is not None and current_time > self.end_time: + return False + + # Check max executions + if self.max_executions is not None and self.execution_count >= self.max_executions: + return False + + # Check if time has come + time_ready = self.next_run_time is not None and current_time >= self.next_run_time + + # Check dependencies + deps_ready = all(dep_id in completed_tasks for dep_id in self.dependencies) + + return time_ready and deps_ready + + def update_next_run_time(self): + """Update next run time after execution.""" + self.execution_count += 1 + + if self.cron_expression: + # Periodic task: calculate next run time from cron + self.next_run_time = self._calculate_next_cron_time() + else: + # One-time task: no next run + self.next_run_time = None + + def _calculate_next_cron_time(self, from_time: Optional[float] = None) -> Optional[float]: + """Calculate next execution time from cron expression. + + Simplified cron format: minute hour day month weekday + Supports: numbers, *, */n (every n), ranges (1-5), lists (1,3,5) + """ + if not self.cron_expression: + return None + + if from_time is None: + from_time = time.time() + + # Try to use croniter library if available + try: + from croniter import croniter + cron = croniter(self.cron_expression, from_time) + return cron.get_next() + except ImportError: + # Fallback: simple interval-based calculation for */n patterns + parts = self.cron_expression.split() + if len(parts) >= 1 and parts[0].startswith('*/'): + try: + interval_minutes = int(parts[0][2:]) + return from_time + (interval_minutes * 60) + except ValueError: + pass + # Default: execute once after 1 hour + return from_time + 3600 + + +@dataclass +class ScheduledTaskStatistics: + """Global scheduling state.""" + total_tasks: int = 0 + pending_tasks: int = 0 + running_tasks: int = 0 + completed_tasks: int = 0 + failed_tasks: int = 0 + abnormal_tasks: int = 0 + + used_cpu: float = 0.0 + used_memory: float = 0.0 + used_cost: float = 0.0 + used_time: float = 0.0 + current_concurrent: int = 0 + + avg_wait_time: float = 0.0 + avg_execution_time: float = 0.0 + throughput: float = 0.0 + + def update_from_task(self, task: Task): + """Update state from task.""" + if task.task_status == TaskStatusValue.INIT: + self.pending_tasks += 1 + elif task.task_status == TaskStatusValue.RUNNING: + self.running_tasks += 1 + self.current_concurrent += 1 + elif task.task_status == TaskStatusValue.SUCCESS: + self.completed_tasks += 1 + self.current_concurrent = max(0, self.current_concurrent - 1) + elif task.task_status == TaskStatusValue.FAILED: + self.failed_tasks += 1 + self.current_concurrent = max(0, self.current_concurrent - 1) + elif task.task_status in (TaskStatusValue.CANCELLED, TaskStatusValue.INTERRUPTED, TaskStatusValue.TIMEOUT): + self.abnormal_tasks += 1 + self.current_concurrent = max(0, self.current_concurrent - 1) + + self.total_tasks += 1 + + +@dataclass +class ResourceQuota: + """Resource quota of the machine for task execution.""" + max_cpu: float = 0.0 + max_memory: float = 0.0 + max_time: float = 0.0 + max_concurrent: int = 0 + + def is_available(self, + used_cpu: float = 0, + used_memory: float = 0, + used_time: float = 0, + current_concurrent: int = 0) -> bool: + """Check if resources are available.""" + checks = [] + if self.max_cpu > 0: + checks.append(used_cpu < self.max_cpu) + if self.max_memory > 0: + checks.append(used_memory < self.max_memory) + if self.max_time > 0: + checks.append(used_time < self.max_time) + if self.max_concurrent > 0: + checks.append(current_concurrent < self.max_concurrent) + + return all(checks) if checks else True From 3017ad6f14fc68b4aba795f745c58c07c142dc28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=8A=B1=E5=B3=B0?= Date: Thu, 5 Mar 2026 16:04:16 +0800 Subject: [PATCH 02/10] task --- aworld/core/common.py | 8 ++++ aworld/runners/hook/task_hooks.py | 61 +++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+) create mode 100644 aworld/runners/hook/task_hooks.py diff --git a/aworld/core/common.py b/aworld/core/common.py index 6c35f8c79..75986d2e9 100644 --- a/aworld/core/common.py +++ b/aworld/core/common.py @@ -133,5 +133,13 @@ class TaskStatusValue: CANCELLED = 'cancelled' INTERRUPTED = 'interrupted' TIMEOUT = 'timeout' + DISABLED = 'disabled' + +class TaskTypeValue: + """Task type constants.""" + INSTANT = 'instant' + SCHEDULED = 'scheduled' + + TaskStatus = Literal['init', 'running', 'success', 'failed', 'cancelled', 'interrupted', 'timeout'] diff --git a/aworld/runners/hook/task_hooks.py b/aworld/runners/hook/task_hooks.py new file mode 100644 index 000000000..be413377f --- /dev/null +++ b/aworld/runners/hook/task_hooks.py @@ -0,0 +1,61 @@ +# coding: utf-8 +# Copyright (c) 2025 inclusionAI. +import abc + +from aworld.core.context.base import Context +from aworld.core.event.base import Message +from aworld.runners.hook.hook_factory import HookFactory +from aworld.runners.hook.hooks import PostLLMCallHook, PreLLMCallHook, Hook +from aworld.utils.common import convert_to_snake + + +@HookFactory.register(name="OnRunTaskProcessHook", + desc="OnRunTaskProcessHook") +class OnRunTaskProcessHook(Hook): + __metaclass__ = abc.ABCMeta + + def name(self): + return convert_to_snake("OnRunTaskProcessHook") + + async def exec(self, message: Message, context: Context = None) -> Message: + # get context + pass + + +@HookFactory.register(name="OnSuccessTaskProcessHook", + desc="OnSuccessTaskProcessHook") +class OnSuccessTaskProcessHook(Hook): + __metaclass__ = abc.ABCMeta + + def name(self): + return convert_to_snake("OnSuccessTaskProcessHook") + + async def exec(self, message: Message, context: Context = None) -> Message: + # get context + pass + + +@HookFactory.register(name="OnErrorTaskProcessHook", + desc="OnErrorTaskProcessHook") +class OnErrorTaskProcessHook(Hook): + __metaclass__ = abc.ABCMeta + + def name(self): + return convert_to_snake("OnErrorTaskProcessHook") + + async def exec(self, message: Message, context: Context = None) -> Message: + # get context + pass + + +@HookFactory.register(name="OnFinishTaskProcessHook", + desc="OnFinishTaskProcessHook") +class OnFinishTaskProcessHook(Hook): + __metaclass__ = abc.ABCMeta + + def name(self): + return convert_to_snake("OnFinishTaskProcessHook") + + async def exec(self, message: Message, context: Context = None) -> Message: + # get context + pass \ No newline at end of file From b543de42d0bc2ae8a912b1d85b1cc423cb82e4ce Mon Sep 17 00:00:00 2001 From: ganrunsheng Date: Thu, 5 Mar 2026 16:37:18 +0800 Subject: [PATCH 03/10] add task manager --- aworld/runners/hook/agent_hooks.py | 10 +- aworld/runners/hook/hooks.py | 14 +- aworld/runners/task_manager.py | 513 +++++++++++++++++++++++++++++ aworld/schedule/types.py | 8 +- 4 files changed, 524 insertions(+), 21 deletions(-) create mode 100644 aworld/runners/task_manager.py diff --git a/aworld/runners/hook/agent_hooks.py b/aworld/runners/hook/agent_hooks.py index fc695cffa..d813c2d7f 100644 --- a/aworld/runners/hook/agent_hooks.py +++ b/aworld/runners/hook/agent_hooks.py @@ -18,10 +18,6 @@ class PreLLMCallContextProcessHook(PreLLMCallHook): def name(self): return convert_to_snake("PreLLMCallContextProcessHook") - async def exec(self, message: Message, context: Context = None) -> Message: - # and do something - pass - @HookFactory.register(name="PostLLMCallContextProcessHook", desc="PostLLMCallContextProcessHook") @@ -32,17 +28,15 @@ class PostLLMCallContextProcessHook(PostLLMCallHook): def name(self): return convert_to_snake("PostLLMCallContextProcessHook") - async def exec(self, message: Message, context: Context = None) -> Message: - # get context - pass - @HookFactory.register(name="PostLLMTrajectoryHook", desc="PostLLMTrajectoryHook") class PostLLMTrajectoryHook(PostLLMCallHook): """Update trajectory after llm call.""" + def name(self): return convert_to_snake("PostLLMTrajectoryHook") + async def exec(self, message: Message, context: Context = None) -> Message: # get context agent_message = message.headers.get("agent_message") diff --git a/aworld/runners/hook/hooks.py b/aworld/runners/hook/hooks.py index bb4305c9c..c465fa4d0 100644 --- a/aworld/runners/hook/hooks.py +++ b/aworld/runners/hook/hooks.py @@ -19,6 +19,7 @@ class HookPoint: PRE_TASK_CALL = "pre_task_call" POST_TASK_CALL = "post_task_call" + class Hook: """Runner hook.""" __metaclass__ = abc.ABCMeta @@ -27,9 +28,9 @@ class Hook: def point(self): """Hook point.""" - @abc.abstractmethod async def exec(self, message: Message, context: Context = None) -> Message: """Execute hook function.""" + pass class StartHook(Hook): @@ -55,13 +56,15 @@ class ErrorHook(Hook): def point(self): return HookPoint.ERROR + class PreLLMCallHook(Hook): """Process in the hook point of the pre_llm_call.""" __metaclass__ = abc.ABCMeta def point(self): return HookPoint.PRE_LLM_CALL - + + class PostLLMCallHook(Hook): """Process in the hook point of the post_llm_call.""" __metaclass__ = abc.ABCMeta @@ -69,12 +72,6 @@ class PostLLMCallHook(Hook): def point(self): return HookPoint.POST_LLM_CALL -class PostToolCallHook(Hook): - """Process in the hook point of the post_tool_call.""" - __metaclass__ = abc.ABCMeta - - def point(self): - return HookPoint.POST_TOOL_CALL class OutputProcessHook(Hook): """Output process hook for processing output data for display.""" @@ -123,4 +120,3 @@ class PostTaskCallHook(Hook): def point(self): return HookPoint.POST_TASK_CALL - diff --git a/aworld/runners/task_manager.py b/aworld/runners/task_manager.py new file mode 100644 index 000000000..507828b30 --- /dev/null +++ b/aworld/runners/task_manager.py @@ -0,0 +1,513 @@ +# coding: utf-8 +# Copyright (c) inclusionAI. +import time +from typing import List, Optional, Set + +from aworld.core.storage.base import Storage +from aworld.core.storage.data import Data +from aworld.core.storage.inmemory_store import InmemoryStorage +from aworld.core.task import Task +from aworld.logs.util import logger + +from aworld.core.common import TaskStatusValue + + +class TaskManager: + """High-level task store providing simple API for task persistence. + + TaskStore acts as a facade over TaskStorage, providing: + - Simple CRUD operations + - Task querying and filtering + - Status-based retrieval + - Ready task detection + + Examples: + # Initialize with storage (required) + + manager = TaskManager(storage=InMemoryStorage()) + + # Add task + task = ScheduledTask(id="task1_id", name="My Task") + await manager.add_task(task) + + # Get task + task = await manager.get_task("task1_id") + + # List tasks + all_tasks = await manager.list() + pending_tasks = await manager.list(status=TaskStatusValue.INIT) + + # Get ready tasks + ready = await manager.get_ready() + """ + + def __init__(self, storage: Storage = InmemoryStorage()): + if storage is None: + raise ValueError("storage parameter is required") + + if not isinstance(storage, Storage): + raise TypeError(f"storage must be TaskStorage instance, got {type(storage)}") + + self.storage = storage + self._cache_completed: Set[str] = set() + + async def add_task(self, task: Task, overwrite: bool = True) -> bool: + """Add a task to the store. + + Args: + task: Task to add + overwrite: Whether to overwrite if task exists + + Returns: + bool: True if successful + + Raises: + ValueError: If task.id is empty + """ + if not task.id: + raise ValueError("Task ID cannot be empty") + + try: + success = await self.storage.create_data( + Data(value=task), + block_id=task.id, + overwrite=overwrite + ) + + if success: + logger.debug(f"Added task: {task.id}") + + return success + except Exception as e: + logger.error(f"Failed to add task {task.id}: {e}") + return False + + async def add_batch(self, tasks: List[Task], overwrite: bool = True) -> int: + """Add multiple tasks in batch. + + Args: + tasks: List of tasks to add + overwrite: Whether to overwrite existing tasks + + Returns: + int: Number of successfully added tasks + """ + success_count = 0 + for task in tasks: + if await self.add_task(task, overwrite=overwrite): + success_count += 1 + + logger.debug(f"Added {success_count}/{len(tasks)} tasks") + return success_count + + async def get_task(self, task_id: str) -> Optional[Task]: + """Get a task by ID. + + Args: + task_id: Task ID + + Returns: + Task or None if not found + """ + try: + tasks = await self.storage.get_data_items(block_id=task_id) + return tasks[0].value if tasks else None + except Exception as e: + logger.error(f"Failed to get task {task_id}: {e}") + return None + + async def update_task(self, task: Task) -> bool: + """Update an existing task. + + Args: + task: Updated task + + Returns: + bool: True if successful + """ + try: + success = await self.storage.update_data( + Data(value=task), + block_id=task.id, + exists=True + ) + + if success: + logger.debug(f"Updated task: {task.id}") + return success + except Exception as e: + logger.error(f"Failed to update task {task.id}: {e}") + return False + + async def update_status( + self, + task_id: str, + status: str, + **fields + ) -> bool: + """Update task status and optional fields. + + Args: + task_id: Task ID + status: New status + **fields: Additional fields to update (e.g., started_at, completed_at) + + Returns: + bool: True if successful + """ + task = await self.get_task(task_id) + if not task: + logger.warning(f"Task {task_id} not found") + return False + + task.task_status = status + + for field, value in fields.items(): + if hasattr(task, field): + setattr(task, field, value) + + return await self.update_task(task) + + async def delete_task(self, task_id: str) -> bool: + """Delete a task. + + Args: + task_id: Task ID + + Returns: + bool: True if successful + """ + try: + success = await self.storage.delete_data( + task_id, + block_id=task_id, + exists=False + ) + + if success: + logger.debug(f"Deleted task: {task_id}") + self._cache_completed.discard(task_id) + + return success + except Exception as e: + logger.error(f"Failed to delete task {task_id}: {e}") + return False + + async def exists(self, task_id: str) -> bool: + """Check if a task exists. + + Args: + task_id: Task ID + + Returns: + bool: True if task exists + """ + return await self.get_task(task_id) is not None + + async def list( + self, + status: Optional[str] = None, + limit: Optional[int] = None, + offset: int = 0 + ) -> List[Task]: + """List tasks with optional filtering. + + Args: + status: Filter by status (optional) + limit: Maximum number of tasks to return + offset: Number of tasks to skip + + Returns: + List of tasks + """ + try: + if status: + tasks = await self.get_tasks_by_status(status) + else: + tasks = await self.storage.get_data_items() + + # Apply offset and limit + tasks = tasks[offset:] if offset > 0 else tasks + tasks = tasks[:limit] if limit else tasks + return tasks + except Exception as e: + logger.error(f"Failed to list tasks: {e}") + return [] + + async def get_ready( + self, + current_time: Optional[float] = None, + limit: Optional[int] = None + ) -> List[Task]: + """Get tasks that are ready to execute. + + Args: + current_time: Current timestamp (default: now) + limit: Maximum number of tasks + + Returns: + List of ready tasks sorted by priority and time + """ + current_time = current_time if current_time else time.time() + + try: + ready_tasks = await self.get_ready_tasks(current_time, limit) + return ready_tasks + except Exception as e: + logger.error(f"Failed to get ready tasks: {e}") + return [] + + async def get_pending(self, limit: Optional[int] = None) -> List[Task]: + """Get all pending (INIT status) tasks.""" + return await self.list(status=TaskStatusValue.INIT, limit=limit) + + async def get_running(self, limit: Optional[int] = None) -> List[Task]: + """Get all running tasks.""" + return await self.list(status=TaskStatusValue.RUNNING, limit=limit) + + async def get_completed(self, limit: Optional[int] = None) -> List[Task]: + """Get all completed (SUCCESS status) tasks.""" + return await self.list(status=TaskStatusValue.SUCCESS, limit=limit) + + async def get_failed(self, limit: Optional[int] = None) -> List[Task]: + """Get all failed tasks.""" + return await self.list(status=TaskStatusValue.FAILED, limit=limit) + + async def get_periodic(self) -> List[Task]: + """Get all periodic tasks (with cron expression).""" + try: + return await self.get_periodic_tasks() + except Exception as e: + logger.error(f"Failed to get periodic tasks: {e}") + return [] + + async def count(self, status: Optional[str] = None) -> int: + """ + Count tasks, optionally filtered by status. + + Args: + status: Filter by status (optional) + + Returns: + Number of tasks + """ + try: + if status: + return await self.count_by_status(status) + else: + return await self.storage.size() + except Exception as e: + logger.error(f"Failed to count tasks: {e}") + return 0 + + async def clear(self) -> bool: + """ + Clear all tasks (use with caution). + + Returns: + bool: True if successful + """ + try: + await self.storage.delete_all() + self._cache_completed.clear() + logger.warning("Cleared all tasks from store") + return True + except Exception as e: + logger.error(f"Failed to clear tasks: {e}") + return False + + async def cleanup_completed( + self, + before_time: Optional[float] = None, + keep_periodic: bool = True + ) -> int: + """Clean up old completed tasks. + + Args: + before_time: Remove tasks completed before this time (default: 7 days ago) + keep_periodic: Keep periodic tasks even if completed + + Returns: + int: Number of tasks removed + """ + if before_time is None: + before_time = time.time() - (7 * 24 * 3600) + + try: + completed_tasks = await self.get_completed() + removed_count = 0 + + for task in completed_tasks: + # Check if task has completed_at attribute (for compatibility) + completed_at = getattr(task, 'completed_at', None) + is_periodic = getattr(task, 'is_periodic', False) + + # Determine if should remove + should_remove = False + if completed_at is not None and completed_at < before_time: + # Only remove if not keeping periodic tasks or task is not periodic + if not keep_periodic or not is_periodic: + should_remove = True + elif completed_at is None: + # For tasks without completed_at, use created_at or current time + created_at = getattr(task, 'created_at', None) + if created_at is not None and created_at < before_time: + if not keep_periodic or not is_periodic: + should_remove = True + + if should_remove: + if await self.delete_task(task.id): + removed_count += 1 + + logger.info(f"Cleaned up {removed_count} completed tasks") + return removed_count + except Exception as e: + logger.error(f"Failed to cleanup completed tasks: {e}") + return 0 + + async def get_statistics(self) -> dict: + """ + Get task statistics. + + Returns: + Dictionary with task counts by status + """ + try: + total = await self.count() + pending = await self.count(TaskStatusValue.INIT) + running = await self.count(TaskStatusValue.RUNNING) + completed = await self.count(TaskStatusValue.SUCCESS) + failed = await self.count(TaskStatusValue.FAILED) + + return { + "total": total, + "pending": pending, + "running": running, + "completed": completed, + "failed": failed, + } + except Exception as e: + logger.error(f"Failed to get statistics: {e}") + return {} + + def mark_completed(self, task_id: str): + """Mark a task as completed in cache (for dependency tracking).""" + self._cache_completed.add(task_id) + + def get_completed_ids(self) -> Set[str]: + """Get set of completed task IDs from cache.""" + return self._cache_completed.copy() + + async def get_tasks_by_status(self, status: str, limit: Optional[int] = None) -> List[Task]: + # Create condition for status filtering + from aworld.core.storage.condition import Condition + + try: + condition = Condition().add("task_status", "==", status) + tasks = await self.storage.select_data(condition) + + # Sort by next_run_time (earlier first), with None values at the end + if tasks and hasattr(tasks[0], "next_run_time"): + tasks.sort(key=lambda t: (t.next_run_time is None, t.next_run_time or 0)) + if limit: + tasks = tasks[:limit] + return tasks + except Exception as e: + logger.error(f"Failed to get tasks by status {status}: {e}") + return [] + + async def get_ready_tasks(self, + current_time: Optional[float] = None, + limit: Optional[int] = None) -> List[Task]: + """Get tasks that are ready to execute. + + A task is ready if: + 1. Status is INIT (pending) + 2. Current time >= next_run_time (or scheduled_time) + 3. All dependencies are completed + 4. Within start_time and end_time constraints + 5. Not exceeded max_executions + + Args: + current_time: Current timestamp (default: now) + limit: Maximum number of tasks to return + + Returns: + List of ready tasks sorted by priority (high to low) and next_run_time + """ + current_time = current_time if current_time else time.time() + try: + pending_tasks = await self.get_tasks_by_status(TaskStatusValue.INIT) + ready_tasks = [] + for task in pending_tasks: + # Check if task has is_ready method (for ScheduledTask compatibility) + if hasattr(task, 'is_ready') and callable(getattr(task, 'is_ready')): + # Use task's is_ready method + if task.is_ready(self._cache_completed, current_time): + ready_tasks.append(task) + else: + # For basic Task without is_ready, check basic readiness + # Check dependencies if task has them + dependencies = getattr(task, 'dependencies', []) + deps_ready = all(dep_id in self._cache_completed for dep_id in dependencies) + + # Check time constraints if task has next_run_time + next_run_time = getattr(task, 'next_run_time', None) + time_ready = True + if next_run_time is not None: + time_ready = current_time >= next_run_time + + if deps_ready and time_ready: + ready_tasks.append(task) + + # Sort by priority (higher first) and next_run_time (earlier first) + if ready_tasks: + def sort_key(t): + priority = getattr(t, "priority", 0) + next_run_time = getattr(t, "next_run_time", None) + return (-priority, next_run_time if next_run_time is not None else float('inf')) + + ready_tasks.sort(key=sort_key) + + ready_tasks = ready_tasks[:limit] if limit else ready_tasks + return ready_tasks + except Exception as e: + logger.error(f"Failed to get ready tasks: {e}") + return [] + + async def get_periodic_tasks(self) -> List[Task]: + """Get all periodic tasks (tasks with cron expression). + + Returns: + List of periodic tasks + """ + try: + # Get all tasks + all_tasks = await self.storage.get_data_items() + + # Filter periodic tasks + return [task for task in all_tasks if hasattr(task, 'is_periodic') and task.is_periodic] + except Exception as e: + logger.error(f"Failed to get periodic tasks: {e}") + return [] + + async def count_by_status(self, status: str) -> int: + """Count tasks by status. + + Args: + status: Task status to count + + Returns: + Number of tasks with the specified status + """ + try: + from aworld.core.storage.condition import Condition + + # Create condition for status filtering + condition = Condition().add("task_status", "==", status) + count = await self.storage.size(condition) + + return count + except Exception as e: + logger.error(f"Failed to count tasks by status {status}: {e}") + return 0 diff --git a/aworld/schedule/types.py b/aworld/schedule/types.py index 77d5cceb5..95fe685e6 100644 --- a/aworld/schedule/types.py +++ b/aworld/schedule/types.py @@ -9,8 +9,8 @@ @dataclass -class InstantTask(Task): - """Instant task with dependencies and metadata.""" +class SchedulableTask(Task): + """Instant schedulable task with dependencies and priority.""" priority: int = 0 retry_count: int = 0 dependencies: List[str] = field(default_factory=list) @@ -27,7 +27,7 @@ class InstantTask(Task): metadata: Dict[str, Any] = field(default_factory=dict) def __lt__(self, other): - if not isinstance(other, InstantTask): + if not isinstance(other, SchedulableTask): return NotImplemented return self.priority < other.priority @@ -51,7 +51,7 @@ def duration(self) -> float: @dataclass -class ScheduledTask(InstantTask): +class ScheduledTask(SchedulableTask): """Unified scheduled task supporting cron-like scheduling, one-time execution, and delayed execution. Scheduling modes (priority order): From 1afbfab7734d3b55e58d0fef29eb4b3a5aeab30a Mon Sep 17 00:00:00 2001 From: ganrunsheng Date: Thu, 5 Mar 2026 17:39:08 +0800 Subject: [PATCH 04/10] add task scheduler beta --- aworld/schedule/scheduler.py | 496 +++++++++++++++++++++++++++++++++++ 1 file changed, 496 insertions(+) create mode 100644 aworld/schedule/scheduler.py diff --git a/aworld/schedule/scheduler.py b/aworld/schedule/scheduler.py new file mode 100644 index 000000000..5d1a84e64 --- /dev/null +++ b/aworld/schedule/scheduler.py @@ -0,0 +1,496 @@ +# coding: utf-8 +# Copyright (c) inclusionAI. +import asyncio +import time +from typing import Optional, List, Callable, Awaitable, Dict, Set + +from aworld.logs.util import logger +from aworld.runners.hook.utils import run_hooks +from aworld.runners.task_manager import TaskManager +from aworld.schedule.strategy import ScheduleStrategy, create_strategy +from aworld.schedule.types import ScheduledTask, ResourceQuota, ScheduledTaskStatistics +from aworld.core.common import TaskStatusValue +from aworld.core.task import Task, TaskResponse + + +class TaskScheduler: + """ + Task Scheduler for managing and executing scheduled tasks. + + The scheduler supports: + - Multiple scheduling strategies (FIFO, Priority, DAG, Auto) + - Resource quota management + - Concurrent task execution + - Periodic task scheduling + - Task dependency resolution + - Lifecycle hooks + + Examples: + # Create scheduler with storage + from aworld.core.storage.inmemory_store import InmemoryStorage + storage = InmemoryStorage() + manager = TaskManager(storage=storage) + scheduler = TaskScheduler(task_manager=manager) + + # Add tasks + task1 = ScheduledTask(id="task1", name="Task 1", priority=10) + await scheduler.add_task(task1) + + # Set execution handler + async def execute_handler(task): + print(f"Executing {task.name}") + await asyncio.sleep(1) + return True + + scheduler.set_executor(execute_handler) + + # Run scheduler + await scheduler.run(max_iterations=10) + + # Or run scheduler in background + scheduler.start() + # ... do other work ... + await scheduler.stop() + """ + + def __init__( + self, + task_manager: TaskManager, + strategy: Optional[ScheduleStrategy] = None, + strategy_type: str = 'auto', + resource_quota: Optional[ResourceQuota] = None, + max_concurrent: int = 10, + poll_interval: float = 1.0, + enable_periodic: bool = True + ): + """ + Initialize TaskScheduler. + + Args: + task_manager: TaskManager instance (required) + strategy: Custom scheduling strategy (optional) + strategy_type: Type of strategy if strategy is None ('auto', 'fifo', 'priority', 'dag') + resource_quota: Resource quota for task execution + max_concurrent: Maximum concurrent tasks + poll_interval: Seconds between scheduling cycles + enable_periodic: Enable periodic task scheduling + """ + if task_manager is None: + raise ValueError("task_manager is required") + + self.task_manager = task_manager + self.strategy = strategy or create_strategy(strategy_type) + self.resource_quota = resource_quota or ResourceQuota(max_concurrent=max_concurrent) + + self.poll_interval = poll_interval + self.enable_periodic = enable_periodic + + # Execution state + self._executor: Optional[Callable[[Task], Awaitable[bool]]] = None + self._running = False + self._scheduler_task: Optional[asyncio.Task] = None + self._active_tasks: Dict[str, asyncio.Task] = {} + self._stop_event = asyncio.Event() + + # Statistics + self.statistics = ScheduledTaskStatistics() + + # Lifecycle hooks + self._on_task_start: Optional[Callable[[Task], Awaitable[None]]] = None + self._on_task_complete: Optional[Callable[[Task, bool], Awaitable[None]]] = None + self._on_task_error: Optional[Callable[[Task, Exception], Awaitable[None]]] = None + + def set_executor(self, executor: Callable[[Task], Awaitable[bool]]): + """ + Set the task executor function. + + Args: + executor: Async function that takes a Task and returns bool (success) + + Examples: + async def my_executor(task): + print(f"Executing {task.name}") + # ... do work ... + return True + + scheduler.set_executor(my_executor) + """ + self._executor = executor + + async def add_task(self, task: ScheduledTask, overwrite: bool = True) -> bool: + """ + Add a task to the scheduler. + + Args: + task: Task to add + overwrite: Whether to overwrite existing task + + Returns: + bool: True if successful + """ + success = await self.task_manager.add_task(task, overwrite=overwrite) + + if success: + logger.info(f"Added task to scheduler: {task.id} ({task.name})") + + return success + + async def add_tasks(self, tasks: List[ScheduledTask], overwrite: bool = True) -> int: + """ + Add multiple tasks to the scheduler. + + Args: + tasks: List of tasks to add + overwrite: Whether to overwrite existing tasks + + Returns: + int: Number of successfully added tasks + """ + return await self.task_manager.add_batch(tasks, overwrite=overwrite) + + async def schedule(self, tasks: Optional[List[ScheduledTask]] = None) -> List[List[ScheduledTask]]: + """Schedule tasks into execution batches using the configured strategy. + + Args: + tasks: Tasks to schedule (default: get ready tasks from manager) + + Returns: + List of task batches, where each batch can be executed in parallel + """ + if tasks is None: + tasks = await self.task_manager.get_ready() + + if not tasks: + return [] + + # Apply scheduling strategy + batches = await self.strategy.schedule(tasks) + logger.debug(f"Scheduled {len(tasks)} tasks into {len(batches)} batches") + return batches + + async def execute(self, task: Task) -> TaskResponse: + """Execute a single task. + + Args: + task: Task to execute + + Returns: + bool: True if successful + """ + if self._executor is None: + logger.error("No executor set. Use set_executor() to configure task execution.") + return None + + # demo + await run_hooks(task.context, "task_start", "scheduler", payload=task) + try: + # Mark task as started + await self.task_manager.update_status( + task.id, + TaskStatusValue.RUNNING, + started_at=time.time() + ) + + logger.info(f"Executing task: {task.id} ({task.name})") + + # Execute task + success = await self._executor(task) + + # Mark task as completed + await self.task_manager.update_status( + task.id, + TaskStatusValue.SUCCESS if success else TaskStatusValue.FAILED, + completed_at=time.time() + ) + + # Update completed cache for dependency tracking + if success: + self.task_manager.mark_completed(task.id) + + logger.info(f"Task {'completed' if success else 'failed'}: {task.id}") + + # Handle periodic tasks + if success and hasattr(task, 'is_periodic') and task.is_periodic: + await self._reschedule_periodic_task(task) + + return None + except Exception as e: + logger.error(f"Error executing task {task.id}: {e}") + + # Mark task as failed + await self.task_manager.update_status( + task.id, + TaskStatusValue.FAILED, + completed_at=time.time() + ) + + return None + + async def _reschedule_periodic_task(self, task: ScheduledTask): + """Reschedule a periodic task for next execution.""" + try: + # Update next run time + task.update_next_run_time() + + # Check if should continue + if task.max_executions and task.execution_count >= task.max_executions: + logger.info(f"Periodic task {task.id} reached max executions") + return + + if task.end_time and time.time() > task.end_time: + logger.info(f"Periodic task {task.id} reached end time") + return + + # Reset for next execution + task.task_status = "init" + task.started_at = None + task.completed_at = None + + # Update in storage + await self.task_manager.update_task(task) + + logger.info(f"Rescheduled periodic task {task.id} for {task.next_run_time}") + + except Exception as e: + logger.error(f"Failed to reschedule periodic task {task.id}: {e}") + + async def execute_batch(self, batch: List[Task]) -> List[bool]: + """ + Execute a batch of tasks concurrently. + + Args: + batch: Batch of tasks to execute + + Returns: + List of success flags for each task + """ + if not batch: + return [] + + # Check resource availability + current_concurrent = len(self._active_tasks) + if self.resource_quota.max_concurrent > 0: + available_slots = self.resource_quota.max_concurrent - current_concurrent + if available_slots <= 0: + logger.warning("Max concurrent tasks reached, skipping batch") + return [False] * len(batch) + + # Limit batch size to available slots + batch = batch[:available_slots] + + # Execute tasks concurrently + tasks = [self.execute(task) for task in batch] + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Convert exceptions to False + return [r if isinstance(r, bool) else False for r in results] + + async def run_once(self) -> int: + """ + Run one scheduling cycle. + + Returns: + int: Number of tasks executed + """ + # Get ready tasks + ready_tasks = await self.task_manager.get_ready() + + if not ready_tasks: + return 0 + + # Schedule tasks into batches + batches = await self.schedule(ready_tasks) + + executed_count = 0 + + # Execute batches + for batch in batches: + results = await self.execute_batch(batch) + executed_count += sum(1 for r in results if r) + + return executed_count + + async def run(self, timeout: Optional[float] = None): + """ + Run the scheduler for a limited time or iterations. + + Args: + max_iterations: Maximum number of scheduling cycles (None = unlimited) + timeout: Maximum time to run in seconds (None = unlimited) + + Examples: + # Run for 10 iterations + await scheduler.run(max_iterations=10) + + # Run for 60 seconds + await scheduler.run(timeout=60) + + # Run until all tasks complete + await scheduler.run() + """ + start_time = time.time() + logger.info("Starting scheduler") + + try: + # Run one cycle + executed = await self.run_once() + + # Check if no tasks to execute + if executed == 0: + pending = await self.task_manager.count(TaskStatusValue.INIT) + if pending == 0: + logger.info("No more tasks to execute") + except Exception as e: + logger.error(f"Error in scheduler: {e}") + raise + finally: + logger.info(f"Scheduler completed, time cost {time.time() - start_time}(s)") + + async def _scheduler_loop(self): + """Main scheduler loop (for background execution).""" + logger.info("Scheduler loop started") + + try: + while self._running: + # Check stop event with timeout + try: + await asyncio.wait_for( + self._stop_event.wait(), + timeout=self.poll_interval + ) + # Stop event was set + break + except asyncio.TimeoutError: + # Timeout - continue with scheduling cycle + pass + + # Run one scheduling cycle + await self.run_once() + + except Exception as e: + logger.error(f"Error in scheduler loop: {e}") + finally: + self._running = False + logger.info("Scheduler loop stopped") + + def start(self): + """ + Start the scheduler in background. + + Examples: + scheduler.start() + # ... scheduler runs in background ... + await scheduler.stop() + """ + if self._running: + logger.warning("Scheduler already running") + return + + self._running = True + self._stop_event.clear() + self._scheduler_task = asyncio.create_task(self._scheduler_loop()) + + logger.info("Scheduler started in background") + + async def stop(self, wait: bool = True): + """ + Stop the background scheduler. + + Args: + wait: Wait for scheduler to complete current cycle + """ + if not self._running: + logger.warning("Scheduler not running") + return + + logger.info("Stopping scheduler...") + + self._running = False + self._stop_event.set() + + if wait and self._scheduler_task: + await self._scheduler_task + + # Cancel active tasks + for task_id, task in self._active_tasks.items(): + if not task.done(): + task.cancel() + logger.debug(f"Cancelled active task: {task_id}") + + self._active_tasks.clear() + + logger.info("Scheduler stopped") + + @property + def is_running(self) -> bool: + """Check if scheduler is running.""" + return self._running + + async def get_statistics(self) -> dict: + """ + Get scheduler statistics. + + Returns: + Dictionary with scheduler statistics + """ + stats = await self.task_manager.get_statistics() + + stats.update({ + "is_running": self.is_running, + "active_tasks": len(self._active_tasks), + "max_concurrent": self.resource_quota.max_concurrent, + "strategy": self.strategy.__class__.__name__, + "poll_interval": self.poll_interval, + }) + + return stats + + async def pause(self): + """Pause the scheduler (can be resumed).""" + if not self._running: + logger.warning("Scheduler not running") + return + + self._stop_event.set() + logger.info("Scheduler paused") + + async def resume(self): + """Resume the paused scheduler.""" + if self._running and self._stop_event.is_set(): + self._stop_event.clear() + logger.info("Scheduler resumed") + else: + logger.warning("Scheduler not paused or not running") + + async def cancel_task(self, task_id: str) -> bool: + """ + Cancel a scheduled task. + + Args: + task_id: ID of task to cancel + + Returns: + bool: True if successful + """ + # Cancel active execution + if task_id in self._active_tasks: + self._active_tasks[task_id].cancel() + del self._active_tasks[task_id] + + # Update task status + return await self.task_manager.update_status( + task_id, + TaskStatusValue.CANCELLED, + completed_at=time.time() + ) + + async def cleanup(self, before_time: Optional[float] = None): + """ + Clean up old completed tasks. + + Args: + before_time: Remove tasks completed before this time + """ + removed = await self.task_manager.cleanup_completed(before_time=before_time) + logger.info(f"Cleaned up {removed} old tasks") + return removed \ No newline at end of file From c85b207fc3e39b238593f5a1aebac72f4f1897db Mon Sep 17 00:00:00 2001 From: ganrunsheng Date: Thu, 5 Mar 2026 21:02:32 +0800 Subject: [PATCH 05/10] add base hooks; add executor --- aworld/core/common.py | 11 ++- aworld/core/context/amni/__init__.py | 1 - aworld/core/context/amni/contexts.py | 2 +- aworld/core/context/amni/state/task_state.py | 2 +- aworld/core/context/base.py | 6 +- aworld/core/event/message_future.py | 6 +- aworld/core/task.py | 6 +- aworld/events/util.py | 16 ++-- aworld/runners/event_runner.py | 12 +-- aworld/runners/handler/background_task.py | 19 ++--- aworld/runners/handler/task.py | 8 +- aworld/runners/hook/agent_hooks.py | 12 +-- aworld/runners/hook/hook_factory.py | 2 +- aworld/runners/hook/hooks.py | 85 +++++++++----------- aworld/runners/hook/task_hooks.py | 65 ++++----------- aworld/runners/hook/tool_hooks.py | 28 +++++++ aworld/runners/task_manager.py | 20 ++--- aworld/schedule/executor.py | 42 ++++++++++ aworld/schedule/scheduler.py | 42 +++++----- aworld/schedule/task_graph.py | 12 +-- aworld/schedule/types.py | 12 +-- 21 files changed, 208 insertions(+), 201 deletions(-) create mode 100644 aworld/runners/hook/tool_hooks.py create mode 100644 aworld/schedule/executor.py diff --git a/aworld/core/common.py b/aworld/core/common.py index 75986d2e9..5b442cd4c 100644 --- a/aworld/core/common.py +++ b/aworld/core/common.py @@ -6,7 +6,6 @@ from typing import Dict, Any, Optional, Union, List, Literal from enum import Enum - from aworld.config import ConfigDict Config = Union[Dict[str, Any], ConfigDict, BaseModel] @@ -96,15 +95,18 @@ class TaskItem(BaseModel): params: Optional[Dict[str, Any]] = {} policy_info: Optional[Any] = None + class CallbackItem(BaseModel): data: Any node_id: str = None actions: List[ActionModel] = [] + class CallbackActionType(str, Enum): BYPASS = "bypass" OVERRIDE = "override" + class CallbackResult(BaseModel): success: bool = False result_data: Any = None @@ -124,7 +126,7 @@ class StreamingMode(enum.Enum): ALL = 'all' -class TaskStatusValue: +class TaskStatus(str): """Task status constants.""" INIT = 'init' RUNNING = 'running' @@ -135,11 +137,8 @@ class TaskStatusValue: TIMEOUT = 'timeout' DISABLED = 'disabled' + class TaskTypeValue: """Task type constants.""" INSTANT = 'instant' SCHEDULED = 'scheduled' - - - -TaskStatus = Literal['init', 'running', 'success', 'failed', 'cancelled', 'interrupted', 'timeout'] diff --git a/aworld/core/context/amni/__init__.py b/aworld/core/context/amni/__init__.py index 9bb6614e1..f00f537b0 100644 --- a/aworld/core/context/amni/__init__.py +++ b/aworld/core/context/amni/__init__.py @@ -9,7 +9,6 @@ from aworld import trace from aworld.config import AgentConfig, AgentMemoryConfig -from aworld.core.common import TaskStatus # lazy import from aworld.core.context.base import Context from aworld.dataset.types import TrajectoryItem diff --git a/aworld/core/context/amni/contexts.py b/aworld/core/context/amni/contexts.py index ebd4c5320..d70b03429 100644 --- a/aworld/core/context/amni/contexts.py +++ b/aworld/core/context/amni/contexts.py @@ -20,7 +20,7 @@ from .state import TaskInput, Summary from .utils import jsonplus from .worksapces import workspace_repo -from ...task import TaskStatusValue +from ...task import TaskStatus class ContextManager(BaseModel): diff --git a/aworld/core/context/amni/state/task_state.py b/aworld/core/context/amni/state/task_state.py index 6166d6175..ec2c03162 100644 --- a/aworld/core/context/amni/state/task_state.py +++ b/aworld/core/context/amni/state/task_state.py @@ -3,7 +3,7 @@ from pydantic import BaseModel, Field from pydantic import field_validator -from aworld.core.task import TaskStatus +from aworld.core.common import TaskStatus from .agent_state import ApplicationAgentState from .common import WorkingState, TaskInput, TaskOutput from ..utils.modelplus import from_dict_to_memory_message diff --git a/aworld/core/context/base.py b/aworld/core/context/base.py index 1d9257ebb..79deacb2a 100644 --- a/aworld/core/context/base.py +++ b/aworld/core/context/base.py @@ -15,7 +15,7 @@ from aworld.utils.common import nest_dict_counter if TYPE_CHECKING: - from aworld.core.task import Task, TaskResponse, TaskStatus, TaskStatusValue + from aworld.core.task import Task, TaskResponse, TaskStatus from aworld.events.manager import EventManager from aworld.core.agent import BaseAgent from aworld.core.context.amni import AgentContextConfig @@ -783,8 +783,8 @@ async def snapshot(self): return checkpoint async def get_task_status(self): - from aworld.core.common import TaskStatusValue - return TaskStatusValue.SUCCESS + from aworld.core.common import TaskStatus + return TaskStatus.SUCCESS async def update_task_status(self, task_id: str, status: 'TaskStatus'): pass diff --git a/aworld/core/event/message_future.py b/aworld/core/event/message_future.py index 730562412..44ff0ccee 100644 --- a/aworld/core/event/message_future.py +++ b/aworld/core/event/message_future.py @@ -84,10 +84,10 @@ async def wait(self, timeout: float = 30.0, context: 'Context' = None): from aworld.logs.util import logger logger.info(f"Waiting for message {self.msg_id}") if context: - from aworld.core.common import TaskStatusValue + from aworld.core.common import TaskStatus task_status = await context.get_task_status() - if (task_status == TaskStatusValue.CANCELLED - or task_status == TaskStatusValue.INTERRUPTED): + if (task_status == TaskStatus.CANCELLED + or task_status == TaskStatus.INTERRUPTED): self.set_empty_result(msg=f"Task {task_status.lower()}: message not sent") logger.info(f"Task {task_status.lower()}: message not sent") return self.result() diff --git a/aworld/core/task.py b/aworld/core/task.py index 27ad50474..f4f0bea35 100644 --- a/aworld/core/task.py +++ b/aworld/core/task.py @@ -12,7 +12,7 @@ from aworld.agents.llm_agent import Agent from aworld.core.agent.swarm import Swarm -from aworld.core.common import Config, Observation, StreamingMode, TaskStatus, TaskStatusValue +from aworld.core.common import Config, Observation, StreamingMode, TaskStatus from aworld.core.context.base import Context from aworld.core.tool.base import Tool, AsyncTool from aworld.output.outputs import Outputs, DefaultOutputs @@ -58,7 +58,7 @@ class Task: max_retry_count: int = field(default=0) timeout: int = field(default=0) observation: Optional[Observation] = field(default=None) - task_status: TaskStatus = field(default=TaskStatusValue.INIT) + task_status: TaskStatus = field(default=TaskStatus.INIT) # streaming support streaming_mode: StreamingMode = field(default=None) @@ -111,7 +111,7 @@ class TaskResponse: msg: str | None = field(default=None) trajectory: List[Dict[str, Any]]= field(default_factory=list) # task final status, e.g. success/failed/cancelled - status: TaskStatus | None = field(default=TaskStatusValue.SUCCESS) + status: TaskStatus | None = field(default=TaskStatus.SUCCESS) def to_dict(self) -> Dict[str, Any]: return { diff --git a/aworld/events/util.py b/aworld/events/util.py index 42c6f18d7..fd04d3b77 100644 --- a/aworld/events/util.py +++ b/aworld/events/util.py @@ -3,7 +3,7 @@ from typing import Callable, Any, List import asyncio -from aworld.core.common import TaskStatusValue +from aworld.core.common import TaskStatus from aworld.core.context.base import Context from aworld.events import eventbus from aworld.core.event.base import Message, Constants @@ -56,9 +56,9 @@ async def send_message(msg: Message): """ context = msg.context if context: - from aworld.core.common import TaskStatusValue + from aworld.core.common import TaskStatus task_status = await context.get_task_status() - if task_status == TaskStatusValue.CANCELLED or task_status == TaskStatusValue.INTERRUPTED: + if task_status == TaskStatus.CANCELLED or task_status == TaskStatus.INTERRUPTED: await _send_finish_message(msg, task_status) return await _send_message(msg) @@ -78,9 +78,9 @@ async def send_and_wait_message(msg: Message) -> List['HandleResult'] | None: """ context = msg.context if context: - from aworld.core.common import TaskStatusValue + from aworld.core.common import TaskStatus task_status = await context.get_task_status() - if task_status == TaskStatusValue.CANCELLED or task_status == TaskStatusValue.INTERRUPTED: + if task_status == TaskStatus.CANCELLED or task_status == TaskStatus.INTERRUPTED: await _send_finish_message(msg, task_status) return None await _send_message(msg) @@ -184,9 +184,9 @@ async def send_message_with_future(msg: Message) -> MessageFuture: if context: - from aworld.core.common import TaskStatusValue + from aworld.core.common import TaskStatus task_status = await context.get_task_status() - if task_status == TaskStatusValue.CANCELLED or task_status == TaskStatusValue.INTERRUPTED: + if task_status == TaskStatus.CANCELLED or task_status == TaskStatus.INTERRUPTED: await _send_finish_message(msg, task_status) # Task cancelled or interrupted, return a completed Future with empty result dummy_msg_id = f"cancelled_{msg.id}" @@ -199,7 +199,7 @@ async def send_message_with_future(msg: Message) -> MessageFuture: future = MessageFuture(msg_id) return future -async def _send_finish_message(msg: Message, status: str = TaskStatusValue.SUCCESS): +async def _send_finish_message(msg: Message, status: str = TaskStatus.SUCCESS): context = msg.context await _send_message(Message(payload=f"Task {status.lower()}",session_id=context.session_id, category=Constants.TASK, headers={"context": context})) diff --git a/aworld/runners/event_runner.py b/aworld/runners/event_runner.py index daab94bcc..29e17dee0 100644 --- a/aworld/runners/event_runner.py +++ b/aworld/runners/event_runner.py @@ -15,7 +15,7 @@ from aworld.dataset.trajectory_storage import get_storage_instance from aworld.core.event.base import Message, Constants, TopicType, ToolMessage, AgentMessage from aworld.core.exceptions import AWorldRuntimeException -from aworld.core.task import Task, TaskResponse, TaskStatusValue +from aworld.core.task import Task, TaskResponse, TaskStatus from aworld.dataset.trajectory_dataset import TrajectoryDataset from aworld.events.manager import EventManager from aworld.logs.util import logger, trajectory_logger @@ -349,7 +349,7 @@ async def _do_run(self): time_cost=( time.time() - start), usage=self.context.token_usage, - status=TaskStatusValue.SUCCESS if not msg else TaskStatusValue.FAILED) + status=TaskStatus.SUCCESS if not msg else TaskStatus.FAILED) break logger.debug(f"{task_flag} task {self.task.id} next message snap") # consume message @@ -424,7 +424,7 @@ def _response(self): self._task_response = TaskResponse(id=self.context.task_id if self.context else "", success=False, msg="Task return None.", - status=TaskStatusValue.FAILED) + status=TaskStatus.FAILED) if self.context.get_task().conf and self.context.get_task().conf.resp_carry_raw_llm_resp == True: self._task_response.raw_llm_resp = self.context.context_info.get('llm_output') self._task_response.trace_id = get_trace_id() @@ -466,14 +466,14 @@ async def should_stop_task(self, message: Message): time_cost=(time.time() - self.start_time), usage=self.context.token_usage, msg=f'Task timeout after {time_cost} seconds.', - status=TaskStatusValue.TIMEOUT + status=TaskStatus.TIMEOUT ) - await self.context.update_task_status(self.task.id, TaskStatusValue.TIMEOUT) + await self.context.update_task_status(self.task.id, TaskStatus.TIMEOUT) return True # Check Task status from context task_status = await self.context.get_task_status() - if task_status == TaskStatusValue.INTERRUPTED or task_status == TaskStatusValue.CANCELLED: + if task_status == TaskStatus.INTERRUPTED or task_status == TaskStatus.CANCELLED: logger.warn(f"{task_flag} task {self.task.id} is {task_status}.") self._task_response = TaskResponse( answer='', diff --git a/aworld/runners/handler/background_task.py b/aworld/runners/handler/background_task.py index 6ded0004c..839d3fe4f 100644 --- a/aworld/runners/handler/background_task.py +++ b/aworld/runners/handler/background_task.py @@ -1,25 +1,18 @@ # coding: utf-8 # Copyright (c) 2025 inclusionAI. import abc -import asyncio -import time from typing import AsyncGenerator, TYPE_CHECKING, Tuple -from env_channel import EnvChannelMessage - from aworld.core.common import TaskItem, Observation -from aworld.core.context.amni import get_context_manager, ContextManager, ApplicationContext, AmniContext from aworld.core.context.base import Context from aworld.core.event.base import Message, Constants, TopicType, BackgroundTaskMessage, AgentMessage -from aworld.core.task import TaskResponse, TaskStatusValue, Task, Runner +from aworld.core.task import TaskResponse, TaskStatus, Task, Runner from aworld.events.util import send_message from aworld.logs.util import logger from aworld.memory.main import MemoryFactory from aworld.memory.models import MemoryHumanMessage, MessageMetadata -from aworld.runner import Runners from aworld.runners import HandlerFactory from aworld.runners.handler.base import DefaultHandler -from aworld.runners.hook.hooks import HookPoint if TYPE_CHECKING: from aworld.runners.event_runner import TaskEventRunner @@ -166,11 +159,11 @@ async def _merge_by_topic(self, message: Message): content = data.content elif isinstance(data, Observation): content = data.content - elif isinstance(data, EnvChannelMessage): - data = data.message - if not agent_id: - agent_id = data.get('env_content', {}).get('agent_id') - content = data + # elif isinstance(data, EnvChannelMessage): + # data = data.message + # if not agent_id: + # agent_id = data.get('env_content', {}).get('agent_id') + # content = data elif isinstance(data, dict): if not agent_id: agent_id = data.get('env_content', {}).get('agent_id') diff --git a/aworld/runners/handler/task.py b/aworld/runners/handler/task.py index 5e99f949b..42bd12572 100644 --- a/aworld/runners/handler/task.py +++ b/aworld/runners/handler/task.py @@ -8,7 +8,7 @@ from aworld.core.common import TaskItem from aworld.core.event.base import Message, Constants, TopicType -from aworld.core.task import TaskResponse, TaskStatusValue +from aworld.core.task import TaskResponse, TaskStatus from aworld.core.tool.base import Tool, AsyncTool from aworld.logs.util import logger, trajectory_logger from aworld.output import Output @@ -88,7 +88,7 @@ async def _do_handle(self, message: Message) -> AsyncGenerator[Message, None]: id=self.runner.task.id, time_cost=(time.time() - self.runner.start_time), usage=self.runner.context.token_usage, - status=TaskStatusValue.FAILED) + status=TaskStatus.FAILED) await self.runner.stop() yield Message(payload=self.runner._task_response, session_id=message.session_id, @@ -149,7 +149,7 @@ async def _do_handle(self, message: Message) -> AsyncGenerator[Message, None]: time_cost=(time.time() - self.runner.start_time), usage=self.runner.context.token_usage, msg=f'cancellation message received: {task_item.msg}', - status=TaskStatusValue.CANCELLED) + status=TaskStatus.CANCELLED) await self.runner.stop() yield Message(payload=self.runner._task_response, session_id=message.session_id, headers=message.headers, topic=TopicType.TASK_RESPONSE) @@ -165,7 +165,7 @@ async def _do_handle(self, message: Message) -> AsyncGenerator[Message, None]: time_cost=(time.time() - self.runner.start_time), usage=self.runner.context.token_usage, msg=f'interruption message received: {task_item.msg}', - status=TaskStatusValue.INTERRUPTED) + status=TaskStatus.INTERRUPTED) await self.runner.stop() yield Message(payload=self.runner._task_response, session_id=message.session_id, headers=message.headers, topic=TopicType.TASK_RESPONSE) diff --git a/aworld/runners/hook/agent_hooks.py b/aworld/runners/hook/agent_hooks.py index d813c2d7f..21eef232e 100644 --- a/aworld/runners/hook/agent_hooks.py +++ b/aworld/runners/hook/agent_hooks.py @@ -1,19 +1,16 @@ # coding: utf-8 # Copyright (c) 2025 inclusionAI. -import abc - from aworld.core.context.base import Context from aworld.core.event.base import Message from aworld.runners.hook.hook_factory import HookFactory -from aworld.runners.hook.hooks import PostLLMCallHook, PreLLMCallHook +from aworld.runners.hook.hooks import OnStartLLMCallHook, OnFinishedLLMCallHook from aworld.utils.common import convert_to_snake @HookFactory.register(name="PreLLMCallContextProcessHook", desc="PreLLMCallContextProcessHook") -class PreLLMCallContextProcessHook(PreLLMCallHook): +class PreLLMCallContextProcessHook(OnStartLLMCallHook): """Process in the hook point of the pre_llm_call.""" - __metaclass__ = abc.ABCMeta def name(self): return convert_to_snake("PreLLMCallContextProcessHook") @@ -21,9 +18,8 @@ def name(self): @HookFactory.register(name="PostLLMCallContextProcessHook", desc="PostLLMCallContextProcessHook") -class PostLLMCallContextProcessHook(PostLLMCallHook): +class PostLLMCallContextProcessHook(OnFinishedLLMCallHook): """Process in the hook point of the post_llm_call.""" - __metaclass__ = abc.ABCMeta def name(self): return convert_to_snake("PostLLMCallContextProcessHook") @@ -31,7 +27,7 @@ def name(self): @HookFactory.register(name="PostLLMTrajectoryHook", desc="PostLLMTrajectoryHook") -class PostLLMTrajectoryHook(PostLLMCallHook): +class PostLLMTrajectoryHook(OnFinishedLLMCallHook): """Update trajectory after llm call.""" def name(self): diff --git a/aworld/runners/hook/hook_factory.py b/aworld/runners/hook/hook_factory.py index b0f7b94a3..c3579b215 100644 --- a/aworld/runners/hook/hook_factory.py +++ b/aworld/runners/hook/hook_factory.py @@ -5,7 +5,7 @@ from aworld.core.factory import Factory from aworld.logs.util import logger -from aworld.runners.hook.hooks import Hook, StartHook, HookPoint +from aworld.runners.hook.hooks import Hook, HookPoint class HookManager(Factory): diff --git a/aworld/runners/hook/hooks.py b/aworld/runners/hook/hooks.py index c465fa4d0..b9a961c32 100644 --- a/aworld/runners/hook/hooks.py +++ b/aworld/runners/hook/hooks.py @@ -4,20 +4,28 @@ from aworld.core.context.base import Context from aworld.core.event.base import Message -from aworld.models.model_response import ModelResponse class HookPoint: START = "start" FINISHED = "finished" ERROR = "error" - PRE_LLM_CALL = "pre_llm_call" - POST_LLM_CALL = "post_llm_call" OUTPUT_PROCESS = "output_process" - PRE_TOOL_CALL = "pre_tool_call" - POST_TOOL_CALL = "post_tool_call" - PRE_TASK_CALL = "pre_task_call" - POST_TASK_CALL = "post_task_call" + ON_START_LLM_CALL = "on_start_llm_call" + ON_FINISHED_LLM_CALL = "on_finished_llm_call" + ON_LLM_CALL = "on_llm_call" + ON_SUCCESS_LLM_CALL = "on_success_llm_call" + ON_ERROR_LLM_CALL = "on_error_llm_call" + ON_START_TOOL_CALL = "on_start_tool_call" + ON_TOOL_CALL = "on_tool_call" + ON_FINISHED_TOOL_CALL = "on_finished_tool_call" + ON_SUCCESS_TOOL_CALL = "on_success_tool_call" + ON_ERROR_TOOL_CALL = "on_error_tool_call" + ON_RUN_TASK = "on_run_task" + ON_SUCCESS_TASK = "on_success_task" + ON_ERROR_TASK = "on_error_task" + ON_START_TASK = "on_start_task" + ON_FINISHED_TASK = "on_finished_task" class Hook: @@ -28,6 +36,10 @@ class Hook: def point(self): """Hook point.""" + def name(self): + """Hook name.""" + return self.__class__.__name__ + async def exec(self, message: Message, context: Context = None) -> Message: """Execute hook function.""" pass @@ -35,7 +47,6 @@ async def exec(self, message: Message, context: Context = None) -> Message: class StartHook(Hook): """Process in the hook point of the start.""" - __metaclass__ = abc.ABCMeta def point(self): return HookPoint.START @@ -43,7 +54,6 @@ def point(self): class FinishedHook(Hook): """Process in the hook point of the finished.""" - __metaclass__ = abc.ABCMeta def point(self): return HookPoint.FINISHED @@ -51,31 +61,38 @@ def point(self): class ErrorHook(Hook): """Process in the hook point of the error.""" - __metaclass__ = abc.ABCMeta def point(self): return HookPoint.ERROR -class PreLLMCallHook(Hook): - """Process in the hook point of the pre_llm_call.""" - __metaclass__ = abc.ABCMeta +class OnStartLLMCallHook(Hook): + def point(self): + return HookPoint.ON_START_LLM_CALL + +class OnFinishedLLMCallHook(Hook): def point(self): - return HookPoint.PRE_LLM_CALL + return HookPoint.ON_FINISHED_LLM_CALL -class PostLLMCallHook(Hook): - """Process in the hook point of the post_llm_call.""" - __metaclass__ = abc.ABCMeta +class OnLLMCallHook(Hook): + def point(self): + return HookPoint.ON_LLM_CALL + + +class OnSuccessLLMCallHook(Hook): + def point(self): + return HookPoint.ON_SUCCESS_LLM_CALL + +class OnErrorLLMCallHook(Hook): def point(self): - return HookPoint.POST_LLM_CALL + return HookPoint.ON_ERROR_LLM_CALL class OutputProcessHook(Hook): """Output process hook for processing output data for display.""" - __metaclass__ = abc.ABCMeta def point(self): return HookPoint.OUTPUT_PROCESS @@ -90,33 +107,3 @@ def process_output_content(self, content: str) -> str: processed content """ return content - - -class PreToolCallHook(Hook): - """Process in the hook point of the pre_tool_call.""" - __metaclass__ = abc.ABCMeta - - def point(self): - return HookPoint.PRE_TOOL_CALL - - -class PostToolCallHook(Hook): - """Process in the hook point of the post_tool_call.""" - __metaclass__ = abc.ABCMeta - - def point(self): - return HookPoint.POST_TOOL_CALL - -class PreTaskCallHook(Hook): - """Process in the hook point of the post_task_call.""" - __metaclass__ = abc.ABCMeta - - def point(self): - return HookPoint.PRE_TASK_CALL - -class PostTaskCallHook(Hook): - """Process in the hook point of the post_task_call.""" - __metaclass__ = abc.ABCMeta - - def point(self): - return HookPoint.POST_TASK_CALL diff --git a/aworld/runners/hook/task_hooks.py b/aworld/runners/hook/task_hooks.py index be413377f..1060f720a 100644 --- a/aworld/runners/hook/task_hooks.py +++ b/aworld/runners/hook/task_hooks.py @@ -1,61 +1,28 @@ # coding: utf-8 # Copyright (c) 2025 inclusionAI. -import abc +from aworld.runners.hook.hooks import Hook, HookPoint -from aworld.core.context.base import Context -from aworld.core.event.base import Message -from aworld.runners.hook.hook_factory import HookFactory -from aworld.runners.hook.hooks import PostLLMCallHook, PreLLMCallHook, Hook -from aworld.utils.common import convert_to_snake +class OnRunHook(Hook): + def point(self): + return HookPoint.ON_RUN_TASK -@HookFactory.register(name="OnRunTaskProcessHook", - desc="OnRunTaskProcessHook") -class OnRunTaskProcessHook(Hook): - __metaclass__ = abc.ABCMeta - def name(self): - return convert_to_snake("OnRunTaskProcessHook") +class OnSuccessHook(Hook): + def point(self): + return HookPoint.ON_SUCCESS_TASK - async def exec(self, message: Message, context: Context = None) -> Message: - # get context - pass +class OnErrorHook(Hook): + def point(self): + return HookPoint.ON_ERROR_TASK -@HookFactory.register(name="OnSuccessTaskProcessHook", - desc="OnSuccessTaskProcessHook") -class OnSuccessTaskProcessHook(Hook): - __metaclass__ = abc.ABCMeta - def name(self): - return convert_to_snake("OnSuccessTaskProcessHook") +class OnStartHook(Hook): + def point(self): + return HookPoint.ON_START_TASK - async def exec(self, message: Message, context: Context = None) -> Message: - # get context - pass - -@HookFactory.register(name="OnErrorTaskProcessHook", - desc="OnErrorTaskProcessHook") -class OnErrorTaskProcessHook(Hook): - __metaclass__ = abc.ABCMeta - - def name(self): - return convert_to_snake("OnErrorTaskProcessHook") - - async def exec(self, message: Message, context: Context = None) -> Message: - # get context - pass - - -@HookFactory.register(name="OnFinishTaskProcessHook", - desc="OnFinishTaskProcessHook") -class OnFinishTaskProcessHook(Hook): - __metaclass__ = abc.ABCMeta - - def name(self): - return convert_to_snake("OnFinishTaskProcessHook") - - async def exec(self, message: Message, context: Context = None) -> Message: - # get context - pass \ No newline at end of file +class OnFinishHook(Hook): + def point(self): + return HookPoint.ON_FINISHED_TASK diff --git a/aworld/runners/hook/tool_hooks.py b/aworld/runners/hook/tool_hooks.py new file mode 100644 index 000000000..310b3e476 --- /dev/null +++ b/aworld/runners/hook/tool_hooks.py @@ -0,0 +1,28 @@ +# coding: utf-8 +# Copyright (c) 2025 inclusionAI. +from aworld.runners.hook.hooks import Hook, HookPoint + + +class OnStartToolCallHook(Hook): + def point(self): + return HookPoint.ON_START_TOOL_CALL + + +class OnFinishedToolCallHook(Hook): + def point(self): + return HookPoint.ON_FINISHED_TOOL_CALL + + +class OnToolCallHook(Hook): + def point(self): + return HookPoint.ON_TOOL_CALL + + +class OnSuccessToolCallHook(Hook): + def point(self): + return HookPoint.ON_SUCCESS_TOOL_CALL + + +class OnErrorToolCallHook(Hook): + def point(self): + return HookPoint.ON_ERROR_TOOL_CALL diff --git a/aworld/runners/task_manager.py b/aworld/runners/task_manager.py index 507828b30..c90e2a801 100644 --- a/aworld/runners/task_manager.py +++ b/aworld/runners/task_manager.py @@ -9,7 +9,7 @@ from aworld.core.task import Task from aworld.logs.util import logger -from aworld.core.common import TaskStatusValue +from aworld.core.common import TaskStatus class TaskManager: @@ -259,19 +259,19 @@ async def get_ready( async def get_pending(self, limit: Optional[int] = None) -> List[Task]: """Get all pending (INIT status) tasks.""" - return await self.list(status=TaskStatusValue.INIT, limit=limit) + return await self.list(status=TaskStatus.INIT, limit=limit) async def get_running(self, limit: Optional[int] = None) -> List[Task]: """Get all running tasks.""" - return await self.list(status=TaskStatusValue.RUNNING, limit=limit) + return await self.list(status=TaskStatus.RUNNING, limit=limit) async def get_completed(self, limit: Optional[int] = None) -> List[Task]: """Get all completed (SUCCESS status) tasks.""" - return await self.list(status=TaskStatusValue.SUCCESS, limit=limit) + return await self.list(status=TaskStatus.SUCCESS, limit=limit) async def get_failed(self, limit: Optional[int] = None) -> List[Task]: """Get all failed tasks.""" - return await self.list(status=TaskStatusValue.FAILED, limit=limit) + return await self.list(status=TaskStatus.FAILED, limit=limit) async def get_periodic(self) -> List[Task]: """Get all periodic tasks (with cron expression).""" @@ -374,10 +374,10 @@ async def get_statistics(self) -> dict: """ try: total = await self.count() - pending = await self.count(TaskStatusValue.INIT) - running = await self.count(TaskStatusValue.RUNNING) - completed = await self.count(TaskStatusValue.SUCCESS) - failed = await self.count(TaskStatusValue.FAILED) + pending = await self.count(TaskStatus.INIT) + running = await self.count(TaskStatus.RUNNING) + completed = await self.count(TaskStatus.SUCCESS) + failed = await self.count(TaskStatus.FAILED) return { "total": total, @@ -437,7 +437,7 @@ async def get_ready_tasks(self, """ current_time = current_time if current_time else time.time() try: - pending_tasks = await self.get_tasks_by_status(TaskStatusValue.INIT) + pending_tasks = await self.get_tasks_by_status(TaskStatus.INIT) ready_tasks = [] for task in pending_tasks: # Check if task has is_ready method (for ScheduledTask compatibility) diff --git a/aworld/schedule/executor.py b/aworld/schedule/executor.py new file mode 100644 index 000000000..48587ba72 --- /dev/null +++ b/aworld/schedule/executor.py @@ -0,0 +1,42 @@ +# coding: utf-8 +# Copyright (c) inclusionAI. +import time +from typing import List, Dict, Any, Optional + +from aworld.config import RunConfig +from aworld.core.common import TaskStatus +from aworld.core.task import Task +from aworld.logs.util import logger +from aworld.runners.runtime_engine import RuntimeEngine +from aworld.runners.utils import runtime_engine +from aworld.utils.common import sync_exec +from aworld.utils.run_util import exec_tasks + + +class RuntimeEngineExecutor: + """Wrapper for RuntimeEngine to execute scheduled tasks, provides execution and status management for each task.""" + + def __init__(self, engine: Optional[RuntimeEngine] = None, run_config: Optional[RunConfig] = None): + self.runtime_engine = engine or sync_exec(runtime_engine, run_config) + + async def execute(self, tasks: List[Task]) -> Dict[str, Any]: + """Execute a list of scheduled tasks using the runtime engine, updates task status and start time before execution.""" + if not tasks: + return {} + + funcs = [] + # Wrap each task in an async executor function + for scheduled_task in tasks: + async def task_execute(st=scheduled_task): + st.task_status = TaskStatus.RUNNING + st.started_at = time.time() + res = await exec_tasks(tasks=[st]) + st.task_status = res.get(st.id).status + st.completed_at = time.time() + return res + + funcs.append(task_execute) + + results = await self.runtime_engine.execute(funcs) + logger.info(f"{self.runtime_engine.name} execute {len(tasks)} tasks finished") + return results diff --git a/aworld/schedule/scheduler.py b/aworld/schedule/scheduler.py index 5d1a84e64..d6a7b1848 100644 --- a/aworld/schedule/scheduler.py +++ b/aworld/schedule/scheduler.py @@ -5,11 +5,12 @@ from typing import Optional, List, Callable, Awaitable, Dict, Set from aworld.logs.util import logger +from aworld.runners.hook.hooks import Hook from aworld.runners.hook.utils import run_hooks from aworld.runners.task_manager import TaskManager from aworld.schedule.strategy import ScheduleStrategy, create_strategy from aworld.schedule.types import ScheduledTask, ResourceQuota, ScheduledTaskStatistics -from aworld.core.common import TaskStatusValue +from aworld.core.common import TaskStatus from aworld.core.task import Task, TaskResponse @@ -18,7 +19,7 @@ class TaskScheduler: Task Scheduler for managing and executing scheduled tasks. The scheduler supports: - - Multiple scheduling strategies (FIFO, Priority, DAG, Auto) + - Multiple scheduling strategies (FIFO, Priority, DAG, Auto, ...) - Resource quota management - Concurrent task execution - Periodic task scheduling @@ -28,8 +29,7 @@ class TaskScheduler: Examples: # Create scheduler with storage from aworld.core.storage.inmemory_store import InmemoryStorage - storage = InmemoryStorage() - manager = TaskManager(storage=storage) + manager = TaskManager(storage=InmemoryStorage()) scheduler = TaskScheduler(task_manager=manager) # Add tasks @@ -42,10 +42,8 @@ async def execute_handler(task): await asyncio.sleep(1) return True - scheduler.set_executor(execute_handler) - # Run scheduler - await scheduler.run(max_iterations=10) + await scheduler.run() # Or run scheduler in background scheduler.start() @@ -61,7 +59,7 @@ def __init__( resource_quota: Optional[ResourceQuota] = None, max_concurrent: int = 10, poll_interval: float = 1.0, - enable_periodic: bool = True + hooks: Dict[str, Hook] = None ): """ Initialize TaskScheduler. @@ -73,7 +71,6 @@ def __init__( resource_quota: Resource quota for task execution max_concurrent: Maximum concurrent tasks poll_interval: Seconds between scheduling cycles - enable_periodic: Enable periodic task scheduling """ if task_manager is None: raise ValueError("task_manager is required") @@ -83,7 +80,6 @@ def __init__( self.resource_quota = resource_quota or ResourceQuota(max_concurrent=max_concurrent) self.poll_interval = poll_interval - self.enable_periodic = enable_periodic # Execution state self._executor: Optional[Callable[[Task], Awaitable[bool]]] = None @@ -95,12 +91,12 @@ def __init__( # Statistics self.statistics = ScheduledTaskStatistics() - # Lifecycle hooks - self._on_task_start: Optional[Callable[[Task], Awaitable[None]]] = None - self._on_task_complete: Optional[Callable[[Task, bool], Awaitable[None]]] = None - self._on_task_error: Optional[Callable[[Task, Exception], Awaitable[None]]] = None + @property + def executor(self): + return self._executor - def set_executor(self, executor: Callable[[Task], Awaitable[bool]]): + @executor.setter + def executor(self, executor: Callable[[Task], Awaitable[bool]]): """ Set the task executor function. @@ -113,7 +109,7 @@ async def my_executor(task): # ... do work ... return True - scheduler.set_executor(my_executor) + scheduler.executor = my_executor """ self._executor = executor @@ -178,7 +174,7 @@ async def execute(self, task: Task) -> TaskResponse: bool: True if successful """ if self._executor is None: - logger.error("No executor set. Use set_executor() to configure task execution.") + logger.error("No executor set. Use .executor to configure task execution.") return None # demo @@ -187,7 +183,7 @@ async def execute(self, task: Task) -> TaskResponse: # Mark task as started await self.task_manager.update_status( task.id, - TaskStatusValue.RUNNING, + TaskStatus.RUNNING, started_at=time.time() ) @@ -199,7 +195,7 @@ async def execute(self, task: Task) -> TaskResponse: # Mark task as completed await self.task_manager.update_status( task.id, - TaskStatusValue.SUCCESS if success else TaskStatusValue.FAILED, + TaskStatus.SUCCESS if success else TaskStatus.FAILED, completed_at=time.time() ) @@ -220,7 +216,7 @@ async def execute(self, task: Task) -> TaskResponse: # Mark task as failed await self.task_manager.update_status( task.id, - TaskStatusValue.FAILED, + TaskStatus.FAILED, completed_at=time.time() ) @@ -337,7 +333,7 @@ async def run(self, timeout: Optional[float] = None): # Check if no tasks to execute if executed == 0: - pending = await self.task_manager.count(TaskStatusValue.INIT) + pending = await self.task_manager.count(TaskStatus.INIT) if pending == 0: logger.info("No more tasks to execute") except Exception as e: @@ -480,7 +476,7 @@ async def cancel_task(self, task_id: str) -> bool: # Update task status return await self.task_manager.update_status( task_id, - TaskStatusValue.CANCELLED, + TaskStatus.CANCELLED, completed_at=time.time() ) @@ -493,4 +489,4 @@ async def cleanup(self, before_time: Optional[float] = None): """ removed = await self.task_manager.cleanup_completed(before_time=before_time) logger.info(f"Cleaned up {removed} old tasks") - return removed \ No newline at end of file + return removed diff --git a/aworld/schedule/task_graph.py b/aworld/schedule/task_graph.py index 70db235f8..f1fc2b8e1 100644 --- a/aworld/schedule/task_graph.py +++ b/aworld/schedule/task_graph.py @@ -3,10 +3,10 @@ from collections import deque from typing import Dict, List, Set, Optional -from aworld.core.common import TaskStatusValue +from aworld.core.common import TaskStatus from aworld.core.task import Task from aworld.logs.util import logger -from aworld.schedule.types import InstantTask +from aworld.schedule.types import SchedulableTask class TaskGraph: @@ -52,7 +52,7 @@ def get_ready_tasks(self) -> List[Task]: """Get all tasks that are ready to execute (no pending dependencies).""" ready_tasks = [] for task_id, node in self.nodes.items(): - if node.task_status == TaskStatusValue.INIT: + if node.task_status == TaskStatus.INIT: if task_id in self.predecessor: pending_predecessors = self.predecessor[task_id] - self.completed if not pending_predecessors: @@ -84,7 +84,7 @@ def mark_failed(self, task_id: str): if task_id in self.successor: for dependent_id in self.successor[task_id]: if dependent_id in self.nodes: - self.nodes[dependent_id].task_status = TaskStatusValue.FAILED + self.nodes[dependent_id].task_status = TaskStatus.FAILED if task_id in self.predecessor: del self.predecessor[task_id] @@ -224,7 +224,7 @@ def get_statistics(self) -> Dict: return { "total_tasks": len(self.nodes), "completed_tasks": len(self.completed), - "pending_tasks": sum(1 for n in self.nodes.values() if n.task_status == TaskStatusValue.INIT), + "pending_tasks": sum(1 for n in self.nodes.values() if n.task_status == TaskStatus.INIT), "ready_tasks": ready_count, "max_depth": len(self.get_execution_order()), "has_cycle": self.has_cycle(), @@ -248,7 +248,7 @@ def validate_dag(dag: 'TaskGraph') -> tuple[bool, Optional[str]]: return True, None @staticmethod - def validate_task_dependencies(task: InstantTask, existing_tasks: Set[str]) -> tuple[bool, Optional[str]]: + def validate_task_dependencies(task: SchedulableTask, existing_tasks: Set[str]) -> tuple[bool, Optional[str]]: """Validate task dependencies before adding to DAG.""" for dep_id in task.dependencies: if dep_id == task.id: diff --git a/aworld/schedule/types.py b/aworld/schedule/types.py index 95fe685e6..b6c8467ad 100644 --- a/aworld/schedule/types.py +++ b/aworld/schedule/types.py @@ -4,7 +4,7 @@ from dataclasses import dataclass, field from typing import List, Optional, Dict, Any, Set -from aworld.core.common import TaskStatus, TaskStatusValue +from aworld.core.common import TaskStatus from aworld.core.task import Task @@ -218,18 +218,18 @@ class ScheduledTaskStatistics: def update_from_task(self, task: Task): """Update state from task.""" - if task.task_status == TaskStatusValue.INIT: + if task.task_status == TaskStatus.INIT: self.pending_tasks += 1 - elif task.task_status == TaskStatusValue.RUNNING: + elif task.task_status == TaskStatus.RUNNING: self.running_tasks += 1 self.current_concurrent += 1 - elif task.task_status == TaskStatusValue.SUCCESS: + elif task.task_status == TaskStatus.SUCCESS: self.completed_tasks += 1 self.current_concurrent = max(0, self.current_concurrent - 1) - elif task.task_status == TaskStatusValue.FAILED: + elif task.task_status == TaskStatus.FAILED: self.failed_tasks += 1 self.current_concurrent = max(0, self.current_concurrent - 1) - elif task.task_status in (TaskStatusValue.CANCELLED, TaskStatusValue.INTERRUPTED, TaskStatusValue.TIMEOUT): + elif task.task_status in (TaskStatus.CANCELLED, TaskStatus.INTERRUPTED, TaskStatus.TIMEOUT): self.abnormal_tasks += 1 self.current_concurrent = max(0, self.current_concurrent - 1) From 429892bffa7d69256160bd8e3b0002fa2e4c9558 Mon Sep 17 00:00:00 2001 From: ganrunsheng Date: Thu, 5 Mar 2026 21:07:25 +0800 Subject: [PATCH 06/10] fix detail --- aworld/core/task.py | 10 +++---- .../evaluations/scorers/output_validators.py | 28 +++++-------------- 2 files changed, 11 insertions(+), 27 deletions(-) diff --git a/aworld/core/task.py b/aworld/core/task.py index f4f0bea35..4c6fb53f8 100644 --- a/aworld/core/task.py +++ b/aworld/core/task.py @@ -1,11 +1,9 @@ # coding: utf-8 # Copyright (c) 2025 inclusionAI. import abc -import asyncio -import enum import uuid from dataclasses import dataclass, field -from typing import Any, Union, List, Dict, Callable, Optional, Literal, TYPE_CHECKING, AsyncGenerator +from typing import Any, Union, List, Dict, Callable, Optional, AsyncGenerator from aworld.core.event.base import Message from aworld.utils.serialized_util import to_serializable @@ -58,7 +56,7 @@ class Task: max_retry_count: int = field(default=0) timeout: int = field(default=0) observation: Optional[Observation] = field(default=None) - task_status: TaskStatus = field(default=TaskStatus.INIT) + task_status: str = field(default=TaskStatus.INIT) # streaming support streaming_mode: StreamingMode = field(default=None) @@ -109,9 +107,9 @@ class TaskResponse: time_cost: float | None = field(default=0.0) success: bool = field(default=False) msg: str | None = field(default=None) - trajectory: List[Dict[str, Any]]= field(default_factory=list) + trajectory: List[Dict[str, Any]] = field(default_factory=list) # task final status, e.g. success/failed/cancelled - status: TaskStatus | None = field(default=TaskStatus.SUCCESS) + status: str | None = field(default=TaskStatus.SUCCESS) def to_dict(self) -> Dict[str, Any]: return { diff --git a/aworld/evaluations/scorers/output_validators.py b/aworld/evaluations/scorers/output_validators.py index c76344789..4e56adb58 100644 --- a/aworld/evaluations/scorers/output_validators.py +++ b/aworld/evaluations/scorers/output_validators.py @@ -343,13 +343,6 @@ def build_judge_data(self, index: int, input: Any, output: Any) -> str: @scorer_register( MetricNames.OUTPUT_RELEVANCE, - model_config=ModelConfig( - llm_provider=os.getenv("VALIDATE_LLM_PROVIDER", os.getenv("LLM_PROVIDER", "openai")), - llm_model_name=os.getenv("VALIDATE_LLM_MODEL_NAME", os.getenv("LLM_MODEL_NAME")), - llm_temperature=float(os.getenv("VALIDATE_LLM_TEMPERATURE", os.getenv("LLM_TEMPERATURE", "0.7"))), - llm_base_url=os.getenv("VALIDATE_LLM_BASE_URL", os.getenv("LLM_BASE_URL")), - llm_api_key=os.getenv("VALIDATE_LLM_API_KEY", os.getenv("LLM_API_KEY")), - ) ) class OutputRelevanceScorer(OutputLlmScore): """Verify the correlation between the answer and the question @@ -369,13 +362,6 @@ def _build_judge_system_prompt(self) -> str: @scorer_register( MetricNames.OUTPUT_COMPLETENESS, - model_config=ModelConfig( - llm_provider=os.getenv("VALIDATE_LLM_PROVIDER", os.getenv("LLM_PROVIDER", "openai")), - llm_model_name=os.getenv("VALIDATE_LLM_MODEL_NAME", os.getenv("LLM_MODEL_NAME")), - llm_temperature=float(os.getenv("VALIDATE_LLM_TEMPERATURE", os.getenv("LLM_TEMPERATURE", "0.7"))), - llm_base_url=os.getenv("VALIDATE_LLM_BASE_URL", os.getenv("LLM_BASE_URL")), - llm_api_key=os.getenv("VALIDATE_LLM_API_KEY", os.getenv("LLM_API_KEY")), - ) ) class OutputCompletenessScorer(OutputLlmScore): """Verify the completeness of the answer @@ -406,13 +392,13 @@ def build_judge_data(self, index: int, input: EvalDataCase, output: Any) -> str: @scorer_register( MetricNames.OUTPUT_QUALITY, - model_config=ModelConfig( - llm_provider=os.getenv("VALIDATE_LLM_PROVIDER", os.getenv("LLM_PROVIDER", "openai")), - llm_model_name=os.getenv("VALIDATE_LLM_MODEL_NAME", os.getenv("LLM_MODEL_NAME")), - llm_temperature=float(os.getenv("VALIDATE_LLM_TEMPERATURE", os.getenv("LLM_TEMPERATURE", "0.7"))), - llm_base_url=os.getenv("VALIDATE_LLM_BASE_URL", os.getenv("LLM_BASE_URL")), - llm_api_key=os.getenv("VALIDATE_LLM_API_KEY", os.getenv("LLM_API_KEY")), - ) + # model_config=ModelConfig( + # llm_provider=os.getenv("VALIDATE_LLM_PROVIDER", os.getenv("LLM_PROVIDER", "openai")), + # llm_model_name=os.getenv("VALIDATE_LLM_MODEL_NAME", os.getenv("LLM_MODEL_NAME")), + # llm_temperature=float(os.getenv("VALIDATE_LLM_TEMPERATURE", os.getenv("LLM_TEMPERATURE", "0.7"))), + # llm_base_url=os.getenv("VALIDATE_LLM_BASE_URL", os.getenv("LLM_BASE_URL")), + # llm_api_key=os.getenv("VALIDATE_LLM_API_KEY", os.getenv("LLM_API_KEY")), + # ) ) class OutputQualityScorer(OutputLlmScore): """Comprehensive evaluation of answer quality From 691f29b71e63dc3297ce1f6536cd79300093ae46 Mon Sep 17 00:00:00 2001 From: "ck.hsq" Date: Thu, 5 Mar 2026 21:16:15 +0800 Subject: [PATCH 07/10] update type --- aworld/core/task.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aworld/core/task.py b/aworld/core/task.py index 4c6fb53f8..c0aa3a1bb 100644 --- a/aworld/core/task.py +++ b/aworld/core/task.py @@ -56,7 +56,7 @@ class Task: max_retry_count: int = field(default=0) timeout: int = field(default=0) observation: Optional[Observation] = field(default=None) - task_status: str = field(default=TaskStatus.INIT) + task_status: TaskStatus = field(default=TaskStatus.INIT) # streaming support streaming_mode: StreamingMode = field(default=None) From 39cc8b8413da71a63aca3c165e46d2a297bcfdd9 Mon Sep 17 00:00:00 2001 From: ganrunsheng Date: Fri, 6 Mar 2026 14:49:47 +0800 Subject: [PATCH 08/10] add executor param; add utility function for task exec --- aworld/schedule/executor.py | 42 -------- aworld/schedule/scheduler.py | 184 +++++++++++++++++------------------ aworld/schedule/types.py | 2 +- 3 files changed, 91 insertions(+), 137 deletions(-) delete mode 100644 aworld/schedule/executor.py diff --git a/aworld/schedule/executor.py b/aworld/schedule/executor.py deleted file mode 100644 index 48587ba72..000000000 --- a/aworld/schedule/executor.py +++ /dev/null @@ -1,42 +0,0 @@ -# coding: utf-8 -# Copyright (c) inclusionAI. -import time -from typing import List, Dict, Any, Optional - -from aworld.config import RunConfig -from aworld.core.common import TaskStatus -from aworld.core.task import Task -from aworld.logs.util import logger -from aworld.runners.runtime_engine import RuntimeEngine -from aworld.runners.utils import runtime_engine -from aworld.utils.common import sync_exec -from aworld.utils.run_util import exec_tasks - - -class RuntimeEngineExecutor: - """Wrapper for RuntimeEngine to execute scheduled tasks, provides execution and status management for each task.""" - - def __init__(self, engine: Optional[RuntimeEngine] = None, run_config: Optional[RunConfig] = None): - self.runtime_engine = engine or sync_exec(runtime_engine, run_config) - - async def execute(self, tasks: List[Task]) -> Dict[str, Any]: - """Execute a list of scheduled tasks using the runtime engine, updates task status and start time before execution.""" - if not tasks: - return {} - - funcs = [] - # Wrap each task in an async executor function - for scheduled_task in tasks: - async def task_execute(st=scheduled_task): - st.task_status = TaskStatus.RUNNING - st.started_at = time.time() - res = await exec_tasks(tasks=[st]) - st.task_status = res.get(st.id).status - st.completed_at = time.time() - return res - - funcs.append(task_execute) - - results = await self.runtime_engine.execute(funcs) - logger.info(f"{self.runtime_engine.name} execute {len(tasks)} tasks finished") - return results diff --git a/aworld/schedule/scheduler.py b/aworld/schedule/scheduler.py index d6a7b1848..116e784a5 100644 --- a/aworld/schedule/scheduler.py +++ b/aworld/schedule/scheduler.py @@ -2,16 +2,19 @@ # Copyright (c) inclusionAI. import asyncio import time -from typing import Optional, List, Callable, Awaitable, Dict, Set +from typing import Optional, List, Callable, Awaitable, Dict, Set, Any from aworld.logs.util import logger from aworld.runners.hook.hooks import Hook from aworld.runners.hook.utils import run_hooks +from aworld.runners.runtime_engine import RuntimeEngine from aworld.runners.task_manager import TaskManager +from aworld.runners.utils import runtime_engine from aworld.schedule.strategy import ScheduleStrategy, create_strategy -from aworld.schedule.types import ScheduledTask, ResourceQuota, ScheduledTaskStatistics +from aworld.schedule.types import ScheduledTask, ResourceQuota, TaskStatistics from aworld.core.common import TaskStatus from aworld.core.task import Task, TaskResponse +from aworld.utils.run_util import exec_tasks class TaskScheduler: @@ -52,14 +55,15 @@ async def execute_handler(task): """ def __init__( - self, - task_manager: TaskManager, - strategy: Optional[ScheduleStrategy] = None, - strategy_type: str = 'auto', - resource_quota: Optional[ResourceQuota] = None, - max_concurrent: int = 10, - poll_interval: float = 1.0, - hooks: Dict[str, Hook] = None + self, + task_manager: TaskManager = TaskManager(), + strategy: Optional[ScheduleStrategy] = None, + strategy_type: str = 'auto', + executor: Optional[Callable[[Task], Awaitable[TaskResponse]]] = None, + resource_quota: Optional[ResourceQuota] = None, + max_concurrent: int = 10, + poll_interval: float = 1.0, + hooks: Dict[str, Hook] = None ): """ Initialize TaskScheduler. @@ -67,29 +71,26 @@ def __init__( Args: task_manager: TaskManager instance (required) strategy: Custom scheduling strategy (optional) - strategy_type: Type of strategy if strategy is None ('auto', 'fifo', 'priority', 'dag') + strategy_type: Type of schedule strategy ('auto', 'fifo', 'priority', 'dag', 'custom_name') + executor: Async function to execute a task (optional, default uses RuntimeEngine) resource_quota: Resource quota for task execution max_concurrent: Maximum concurrent tasks poll_interval: Seconds between scheduling cycles """ - if task_manager is None: - raise ValueError("task_manager is required") - self.task_manager = task_manager self.strategy = strategy or create_strategy(strategy_type) self.resource_quota = resource_quota or ResourceQuota(max_concurrent=max_concurrent) - self.poll_interval = poll_interval # Execution state - self._executor: Optional[Callable[[Task], Awaitable[bool]]] = None + self._executor: Optional[Callable[[Task], Awaitable[TaskResponse]]] = executor or execute_schedulable_tasks self._running = False self._scheduler_task: Optional[asyncio.Task] = None self._active_tasks: Dict[str, asyncio.Task] = {} self._stop_event = asyncio.Event() # Statistics - self.statistics = ScheduledTaskStatistics() + self.statistics = TaskStatistics() @property def executor(self): @@ -107,43 +108,12 @@ def executor(self, executor: Callable[[Task], Awaitable[bool]]): async def my_executor(task): print(f"Executing {task.name}") # ... do work ... - return True + return TaskResponse(success=True, ...) scheduler.executor = my_executor """ self._executor = executor - async def add_task(self, task: ScheduledTask, overwrite: bool = True) -> bool: - """ - Add a task to the scheduler. - - Args: - task: Task to add - overwrite: Whether to overwrite existing task - - Returns: - bool: True if successful - """ - success = await self.task_manager.add_task(task, overwrite=overwrite) - - if success: - logger.info(f"Added task to scheduler: {task.id} ({task.name})") - - return success - - async def add_tasks(self, tasks: List[ScheduledTask], overwrite: bool = True) -> int: - """ - Add multiple tasks to the scheduler. - - Args: - tasks: List of tasks to add - overwrite: Whether to overwrite existing tasks - - Returns: - int: Number of successfully added tasks - """ - return await self.task_manager.add_batch(tasks, overwrite=overwrite) - async def schedule(self, tasks: Optional[List[ScheduledTask]] = None) -> List[List[ScheduledTask]]: """Schedule tasks into execution batches using the configured strategy. @@ -173,12 +143,10 @@ async def execute(self, task: Task) -> TaskResponse: Returns: bool: True if successful """ - if self._executor is None: - logger.error("No executor set. Use .executor to configure task execution.") - return None + if not self._executor: + res = await execute_schedulable_tasks(await runtime_engine(), [task]) + return res.get(task.id) - # demo - await run_hooks(task.context, "task_start", "scheduler", payload=task) try: # Mark task as started await self.task_manager.update_status( @@ -190,26 +158,26 @@ async def execute(self, task: Task) -> TaskResponse: logger.info(f"Executing task: {task.id} ({task.name})") # Execute task - success = await self._executor(task) + result: TaskResponse = await self._executor(task) # Mark task as completed await self.task_manager.update_status( task.id, - TaskStatus.SUCCESS if success else TaskStatus.FAILED, + TaskStatus.SUCCESS if result.success else TaskStatus.FAILED, completed_at=time.time() ) # Update completed cache for dependency tracking - if success: + if result.success: self.task_manager.mark_completed(task.id) - logger.info(f"Task {'completed' if success else 'failed'}: {task.id}") + logger.info(f"Task {'completed' if result.success else 'failed'}: {task.id}") # Handle periodic tasks - if success and hasattr(task, 'is_periodic') and task.is_periodic: + if result.success and hasattr(task, 'is_periodic') and task.is_periodic: await self._reschedule_periodic_task(task) - return None + return result except Exception as e: logger.error(f"Error executing task {task.id}: {e}") @@ -219,8 +187,7 @@ async def execute(self, task: Task) -> TaskResponse: TaskStatus.FAILED, completed_at=time.time() ) - - return None + return TaskResponse() async def _reschedule_periodic_task(self, task: ScheduledTask): """Reschedule a periodic task for next execution.""" @@ -238,7 +205,7 @@ async def _reschedule_periodic_task(self, task: ScheduledTask): return # Reset for next execution - task.task_status = "init" + task.task_status = TaskStatus.INIT task.started_at = None task.completed_at = None @@ -250,18 +217,17 @@ async def _reschedule_periodic_task(self, task: ScheduledTask): except Exception as e: logger.error(f"Failed to reschedule periodic task {task.id}: {e}") - async def execute_batch(self, batch: List[Task]) -> List[bool]: - """ - Execute a batch of tasks concurrently. + async def execute_batch(self, batch: List[Task]) -> Dict[str, TaskResponse]: + """Execute a batch of tasks concurrently. Args: batch: Batch of tasks to execute Returns: - List of success flags for each task + Dict of response for each task """ if not batch: - return [] + return {} # Check resource availability current_concurrent = len(self._active_tasks) @@ -269,7 +235,7 @@ async def execute_batch(self, batch: List[Task]) -> List[bool]: available_slots = self.resource_quota.max_concurrent - current_concurrent if available_slots <= 0: logger.warning("Max concurrent tasks reached, skipping batch") - return [False] * len(batch) + return {} # Limit batch size to available slots batch = batch[:available_slots] @@ -278,12 +244,10 @@ async def execute_batch(self, batch: List[Task]) -> List[bool]: tasks = [self.execute(task) for task in batch] results = await asyncio.gather(*tasks, return_exceptions=True) - # Convert exceptions to False - return [r if isinstance(r, bool) else False for r in results] + return {res.id: res for res in results} - async def run_once(self) -> int: - """ - Run one scheduling cycle. + async def _run_once(self) -> int: + """Run one scheduling cycle. Returns: int: Number of tasks executed @@ -307,17 +271,12 @@ async def run_once(self) -> int: return executed_count async def run(self, timeout: Optional[float] = None): - """ - Run the scheduler for a limited time or iterations. + """Run the scheduler for a limited time or iterations. Args: - max_iterations: Maximum number of scheduling cycles (None = unlimited) timeout: Maximum time to run in seconds (None = unlimited) Examples: - # Run for 10 iterations - await scheduler.run(max_iterations=10) - # Run for 60 seconds await scheduler.run(timeout=60) @@ -329,7 +288,7 @@ async def run(self, timeout: Optional[float] = None): try: # Run one cycle - executed = await self.run_once() + executed = await self._run_once() # Check if no tasks to execute if executed == 0: @@ -361,8 +320,7 @@ async def _scheduler_loop(self): pass # Run one scheduling cycle - await self.run_once() - + await self._run_once() except Exception as e: logger.error(f"Error in scheduler loop: {e}") finally: @@ -370,8 +328,7 @@ async def _scheduler_loop(self): logger.info("Scheduler loop stopped") def start(self): - """ - Start the scheduler in background. + """Start the scheduler in background. Examples: scheduler.start() @@ -389,8 +346,7 @@ def start(self): logger.info("Scheduler started in background") async def stop(self, wait: bool = True): - """ - Stop the background scheduler. + """Stop the background scheduler. Args: wait: Wait for scheduler to complete current cycle @@ -415,7 +371,7 @@ async def stop(self, wait: bool = True): self._active_tasks.clear() - logger.info("Scheduler stopped") + logger.info("Scheduler stopped!") @property def is_running(self) -> bool: @@ -423,8 +379,7 @@ def is_running(self) -> bool: return self._running async def get_statistics(self) -> dict: - """ - Get scheduler statistics. + """Get scheduler statistics. Returns: Dictionary with scheduler statistics @@ -458,9 +413,28 @@ async def resume(self): else: logger.warning("Scheduler not paused or not running") - async def cancel_task(self, task_id: str) -> bool: + async def add_task(self, task: ScheduledTask, overwrite: bool = True) -> bool: + """Add a task to the scheduler. + + Args: + task: Task to add + overwrite: Whether to overwrite existing task + + Returns: + bool: True if successful """ - Cancel a scheduled task. + success = await self.task_manager.add_task(task, overwrite=overwrite) + + if success: + logger.info(f"Added task to scheduler: {task.id} ({task.name})") + return success + + async def add_tasks(self, tasks: List[ScheduledTask], overwrite: bool = True) -> int: + """Add multiple tasks to the scheduler.""" + return await self.task_manager.add_batch(tasks, overwrite=overwrite) + + async def cancel_task(self, task_id: str) -> bool: + """Cancel a scheduled task. Args: task_id: ID of task to cancel @@ -481,8 +455,7 @@ async def cancel_task(self, task_id: str) -> bool: ) async def cleanup(self, before_time: Optional[float] = None): - """ - Clean up old completed tasks. + """Clean up old completed tasks. Args: before_time: Remove tasks completed before this time @@ -490,3 +463,26 @@ async def cleanup(self, before_time: Optional[float] = None): removed = await self.task_manager.cleanup_completed(before_time=before_time) logger.info(f"Cleaned up {removed} old tasks") return removed + + +async def execute_schedulable_tasks(runtime_engine: RuntimeEngine, tasks: List[Task]) -> Dict[str, Any]: + """Execute a list of scheduled tasks using the runtime engine, updates task status and start time before execution.""" + if not tasks: + return {} + + funcs = [] + # Wrap each task in an async executor function + for scheduled_task in tasks: + async def task_execute(st=scheduled_task): + st.task_status = TaskStatus.RUNNING + st.started_at = time.time() + res = await exec_tasks(tasks=[st]) + st.task_status = res.get(st.id).status + st.completed_at = time.time() + return res + + funcs.append(task_execute) + + results = await runtime_engine.execute(funcs) + logger.info(f"{runtime_engine.name} execute {len(tasks)} tasks finished") + return results diff --git a/aworld/schedule/types.py b/aworld/schedule/types.py index b6c8467ad..e262717bc 100644 --- a/aworld/schedule/types.py +++ b/aworld/schedule/types.py @@ -197,7 +197,7 @@ def _calculate_next_cron_time(self, from_time: Optional[float] = None) -> Option @dataclass -class ScheduledTaskStatistics: +class TaskStatistics: """Global scheduling state.""" total_tasks: int = 0 pending_tasks: int = 0 From a4197aef17609020d80f4ac8157e09828eafb039 Mon Sep 17 00:00:00 2001 From: ganrunsheng Date: Mon, 9 Mar 2026 13:44:15 +0800 Subject: [PATCH 09/10] fix some detail --- aworld/memory/utils.py | 16 ---------------- aworld/schedule/scheduler.py | 6 ++++-- aworld/schedule/strategy.py | 4 ++-- tests/memory/agent/self_evolving_agent.py | 14 +++++++++++++- 4 files changed, 19 insertions(+), 21 deletions(-) delete mode 100644 aworld/memory/utils.py diff --git a/aworld/memory/utils.py b/aworld/memory/utils.py deleted file mode 100644 index 1b094df73..000000000 --- a/aworld/memory/utils.py +++ /dev/null @@ -1,16 +0,0 @@ - - -from aworld.core.memory import MemoryItem - - -def build_history_context(history_messages: list[MemoryItem]) -> str: - """ - Build history context from history messages. - """ - history_context = "" - for message in history_messages: - if message.role == "user": - history_context += f"User: {message.content}\n" - else: - history_context += f"Agent: {message.content}\n" - return history_context \ No newline at end of file diff --git a/aworld/schedule/scheduler.py b/aworld/schedule/scheduler.py index 116e784a5..87ed778e3 100644 --- a/aworld/schedule/scheduler.py +++ b/aworld/schedule/scheduler.py @@ -97,7 +97,7 @@ def executor(self): return self._executor @executor.setter - def executor(self, executor: Callable[[Task], Awaitable[bool]]): + def executor(self, executor: Callable[[Task], Awaitable[TaskResponse]]): """ Set the task executor function. @@ -477,7 +477,9 @@ async def task_execute(st=scheduled_task): st.task_status = TaskStatus.RUNNING st.started_at = time.time() res = await exec_tasks(tasks=[st]) - st.task_status = res.get(st.id).status + task_response = res.get(st.id) + if task_response: + st.task_status = task_response.status st.completed_at = time.time() return res diff --git a/aworld/schedule/strategy.py b/aworld/schedule/strategy.py index 98a819c26..9e4b663a4 100644 --- a/aworld/schedule/strategy.py +++ b/aworld/schedule/strategy.py @@ -163,7 +163,7 @@ def register_strategy(strategy_type: str, strategy: Type[ScheduleStrategy]) -> b def create_strategy(strategy_type: Optional[str] = None, **kwargs) -> ScheduleStrategy: """Create strategy instance.""" - return STRATEGY_MAP.get(strategy_type, AutoStrategy())(**kwargs) + return STRATEGY_MAP.get(strategy_type, AutoStrategy)(**kwargs) class TaskQueue(ABC): @@ -239,7 +239,7 @@ def __init__(self): def push(self, task: ScheduledTask): # (priority_value, counter, task) priority_value = task.priority - heapq.heappush(self.heap, (priority_value, self.counter, task)) + heapq.heappush(self.heap, (-priority_value, self.counter, task)) self.counter += 1 def pop(self) -> Optional[ScheduledTask]: diff --git a/tests/memory/agent/self_evolving_agent.py b/tests/memory/agent/self_evolving_agent.py index fadc69541..ce2d9837f 100644 --- a/tests/memory/agent/self_evolving_agent.py +++ b/tests/memory/agent/self_evolving_agent.py @@ -16,7 +16,6 @@ from aworld.memory.main import MemoryFactory from aworld.memory.models import LongTermMemoryTriggerParams, MemoryAIMessage, MessageMetadata, UserProfile, \ MemoryHumanMessage -from aworld.memory.utils import build_history_context from aworld.output import AworldUI from aworld.output.utils import load_workspace from aworld.prompt import Prompt @@ -25,6 +24,19 @@ from tests.memory.prompts import SELF_EVOLVING_USER_INPUT_REWRITE_PROMPT, RESEARCH_PROMPT +def build_history_context(history_messages: list[MemoryItem]) -> str: + """ + Build history context from history messages. + """ + history_context = "" + for message in history_messages: + if message.role == "user": + history_context += f"User: {message.content}\n" + else: + history_context += f"Agent: {message.content}\n" + return history_context + + class SuperAgent: """ Super agent From 55f6cc0e9b41ed445a59e87437df39771a79f77e Mon Sep 17 00:00:00 2001 From: ganrunsheng Date: Mon, 9 Mar 2026 16:22:52 +0800 Subject: [PATCH 10/10] add list and delete api --- aworld/runners/task_manager.py | 2 +- aworld/schedule/scheduler.py | 17 ++++++++++++++--- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/aworld/runners/task_manager.py b/aworld/runners/task_manager.py index c90e2a801..bfbe7f440 100644 --- a/aworld/runners/task_manager.py +++ b/aworld/runners/task_manager.py @@ -35,7 +35,7 @@ class TaskManager: # List tasks all_tasks = await manager.list() - pending_tasks = await manager.list(status=TaskStatusValue.INIT) + pending_tasks = await manager.list(status=TaskStatus.INIT) # Get ready tasks ready = await manager.get_ready() diff --git a/aworld/schedule/scheduler.py b/aworld/schedule/scheduler.py index 87ed778e3..4ad6dfc3c 100644 --- a/aworld/schedule/scheduler.py +++ b/aworld/schedule/scheduler.py @@ -425,14 +425,25 @@ async def add_task(self, task: ScheduledTask, overwrite: bool = True) -> bool: """ success = await self.task_manager.add_task(task, overwrite=overwrite) - if success: - logger.info(f"Added task to scheduler: {task.id} ({task.name})") + logger.info(f"Added task to scheduler {success}: {task.id} ({task.name})") return success async def add_tasks(self, tasks: List[ScheduledTask], overwrite: bool = True) -> int: """Add multiple tasks to the scheduler.""" return await self.task_manager.add_batch(tasks, overwrite=overwrite) + async def list_tasks(self, status: str = None, limit: Optional[int] = None, offset: int = 0) -> List[Task]: + tasks = await self.task_manager.list(status=status, limit=limit, offset=offset) + + logger.info(f"found {len(tasks)} tasks") + return tasks + + async def delete_task(self, task_id: str) -> bool: + success = await self.task_manager.delete_task(task_id=task_id) + + logger.info(f"Deleted task {success}: {task_id}") + return success + async def cancel_task(self, task_id: str) -> bool: """Cancel a scheduled task. @@ -481,7 +492,7 @@ async def task_execute(st=scheduled_task): if task_response: st.task_status = task_response.status st.completed_at = time.time() - return res + return task_response funcs.append(task_execute)