Skip to content

Commit 48baf7f

Browse files
authored
Create supervisor.py
1 parent 3f0c034 commit 48baf7f

File tree

1 file changed

+311
-0
lines changed

1 file changed

+311
-0
lines changed

core/agent/supervisor.py

Lines changed: 311 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,311 @@
1+
"""
2+
Supervisor Agent - Central Orchestrator for Agent Swarm Coordination
3+
"""
4+
5+
from __future__ import annotations
6+
import asyncio
7+
from collections import deque
8+
from datetime import datetime, timedelta
9+
from typing import Dict, List, Optional, Tuple, Any
10+
import uuid
11+
import numpy as np
12+
import psutil
13+
from pydantic import BaseModel, Field, validator
14+
from .base import BaseAgent, AgentMessage, AgentID, AgentConfig, AgentRegistry
15+
from .worker import WorkerAgent, WorkerMetrics, TaskResult, TaskRequest
16+
17+
# Custom Types
18+
# Maps each agent ID to the latest state vector recorded for it. In this file
# the vector is built by SupervisorAgent._update_swarm_state as
# [cpu_usage, mem_usage_gb, active_tasks, queue_size, timestamp].
SwarmState = Dict[AgentID, np.ndarray] # State vectors of all agents
19+
# Array returned by the policy network's predict(). SupervisorAgent._dispatch_tasks
# iterates it row by row (one row per agent) and reads element 0 of each row.
PolicyVector = np.ndarray # Output from RL policy network
20+
21+
class SupervisorConfig(AgentConfig):
    """Supervisor-specific settings layered on top of ``AgentConfig``."""

    # Upper bound on the number of agents; must be strictly positive.
    swarm_size_limit: int = Field(1000, gt=0)
    # Heartbeat period, in seconds.
    heartbeat_interval: int = 30
    # Consecutive failures tolerated before remediation kicks in.
    failure_threshold: int = 3
    # Accepted values: rr, priority, rl_priority (enforced by the validator).
    scheduling_algorithm: str = "rl_priority"
    # Per-resource weights; not read anywhere in the visible part of this module.
    resource_weights: Dict[str, float] = {"cpu": 1.0, "mem_gb": 0.5}

    @validator('scheduling_algorithm')
    def validate_algorithm(cls, v):
        """Return ``v`` unchanged if it names a supported algorithm.

        Raises:
            ValueError: if ``v`` is not one of the supported algorithm names.
        """
        allowed = ["rr", "priority", "rl_priority"]
        if v in allowed:
            return v
        raise ValueError(f"Algorithm must be one of {allowed}")
35+
36+
class SwarmHealthReport(BaseModel):
    """Snapshot of swarm-wide health figures."""

    # Number of agents that have an entry in the supervisor's swarm state.
    total_agents: int
    # Number of workers currently reporting metrics.
    active_workers: int
    # Mean CPU usage across reporting workers (0 when none report).
    avg_cpu_util: float
    # Mean memory usage across reporting workers (0 when none report).
    avg_mem_util: float
    # Tasks still waiting in the supervisor's queue.
    pending_tasks: int
    # Agents present in the swarm state but absent from the worker metrics.
    dead_agents: List[AgentID]
44+
45+
class TaskAssignment(BaseModel):
    """Instruction telling one worker to run one task."""

    # Defaults to "task_" followed by the first 8 hex chars of a random UUID4.
    task_id: str = Field(default_factory=lambda: f"task_{uuid.uuid4().hex[:8]}")
    # Agent that should execute the task.
    worker_id: AgentID
    # Opaque task data forwarded to the worker.
    payload: Dict[str, Any]
    # Point in time by which the task is expected to finish.
    deadline: datetime
    priority: int = 1
52+
53+
class SupervisorAgent(BaseAgent):
    """
    Central coordination agent for swarm management

    Key Responsibilities:
    - Global state maintenance
    - RL-driven scheduling
    - Fault detection & recovery
    - Resource optimization
    - Swarm autoscaling

    NOTE(review): this class relies on members that are not defined in this
    module and are presumably provided by ``BaseAgent`` -- ``_registry``,
    ``_logger``, ``_is_running``, ``_send_message`` and
    ``_trigger_remediation``. Confirm against ``.base``.
    """

    def __init__(self, agent_id: AgentID):
        """Create a supervisor with default config and empty swarm state."""
        super().__init__(agent_id)
        self.config = SupervisorConfig()
        self._swarm_state: SwarmState = {}
        # Bounded queue: appending to a full deque drops the oldest entry.
        self._task_queue = deque(maxlen=10000)
        self._failure_counts: Dict[AgentID, int] = {}
        self._policy_network = self._init_policy_network()
        # Created lazily in _execute_policy so the lock is bound to the loop
        # that is actually running; a lock already set by the base class is kept.
        self._policy_lock: Optional[asyncio.Lock] = getattr(self, "_policy_lock", None)
        # Strong references to fire-and-forget tasks; the event loop itself
        # only keeps weak references to tasks.
        self._background_tasks: set = set()
        self._last_heartbeat = datetime.utcnow()

    async def _process_message(self, message: AgentMessage) -> Dict[str, Any]:
        """Route an incoming message to the handler for its payload type.

        Returns the handler's response dict, or
        ``{"status": "unhandled_message_type"}`` for unknown payload types.
        """
        if message.payload_type == "TaskResult":
            return await self._handle_task_result(TaskResult(**message.payload))
        if message.payload_type == "WorkerMetrics":
            return await self._update_swarm_state(message.sender, message.payload)
        return {"status": "unhandled_message_type"}

    async def _handle_task_result(self, result: TaskResult) -> Dict[str, Any]:
        """Record a task outcome and trigger remediation on failure.

        The failure threshold is documented as *consecutive* failures, so a
        successful result clears the worker's counter.
        """
        if result.success:
            self._failure_counts.pop(result.worker_id, None)
        else:
            self._failure_counts[result.worker_id] = (
                self._failure_counts.get(result.worker_id, 0) + 1
            )
            await self._trigger_remediation(result.worker_id)
        return {"action": "acknowledged"}

    async def _update_swarm_state(self, agent_id: AgentID, metrics: Dict) -> Dict:
        """Store the latest state vector for ``agent_id``.

        ``metrics`` must contain ``cpu_usage``, ``mem_usage_gb``,
        ``active_tasks`` and ``queue_size``; a missing key raises ``KeyError``.
        The current UTC timestamp is appended as the last component.
        """
        state_vector = np.array([
            metrics["cpu_usage"],
            metrics["mem_usage_gb"],
            metrics["active_tasks"],
            metrics["queue_size"],
            datetime.utcnow().timestamp(),
        ])
        self._swarm_state[agent_id] = state_vector
        return {"status": "state_updated"}

    async def _execute_policy(self, state: Any) -> np.ndarray:
        """Run the policy network over the current swarm state.

        ``state`` is accepted for interface compatibility but is not read; the
        input is always built from ``self._swarm_state``. Returns one row of
        policy output per known agent, or an empty array when no agent has
        reported state yet (``np.stack`` would raise on an empty list).
        """
        if not self._swarm_state:
            return np.empty((0, 0))
        state_tensor = np.stack(list(self._swarm_state.values()))
        if self._policy_lock is None:
            self._policy_lock = asyncio.Lock()
        async with self._policy_lock:
            policy_output = self._policy_network.predict(state_tensor)
        return policy_output

    def _init_policy_network(self) -> Any:
        """Initialize RL policy model (placeholder implementation)"""
        class MockPolicyNetwork:
            def predict(self, state: np.ndarray) -> np.ndarray:
                return np.random.rand(state.shape[0], 5)  # 5 actions per agent
        return MockPolicyNetwork()

    async def _coordinate_swarm(self) -> None:
        """Main coordination loop; runs one pass per second while running."""
        while self._is_running:
            # 1. Check swarm health
            health_report = self._generate_health_report()

            # 2. Execute RL policy
            policy_vector = await self._execute_policy(health_report)

            # 3. Dispatch tasks
            await self._dispatch_tasks(policy_vector)

            # 4. Handle autoscaling
            if len(self._swarm_state) < self.config.swarm_size_limit:
                await self._scale_swarm()

            # 5. Failure recovery
            await self._recover_failed_agents()

            await asyncio.sleep(1)

    async def _dispatch_tasks(self, policy_vector: PolicyVector) -> None:
        """Distribute queued tasks to workers according to the policy output.

        The first action of each row, scaled by 10, is the number of tasks
        offered to that agent in this pass. A task whose assignment could not
        be sent is put back on the queue and the worker is skipped for the
        rest of the pass.
        """
        # Snapshot the keys: the awaits below let other coroutines add agents,
        # which would otherwise invalidate a live dict iterator.
        agent_ids = list(self._swarm_state)
        for agent_id, actions in zip(agent_ids, policy_vector):
            if agent_id not in WorkerAgent.get_worker_metrics():
                continue

            # Decode policy actions
            task_capacity = int(actions[0] * 10)  # Max 10 tasks per dispatch
            for _ in range(task_capacity):
                if not self._task_queue:
                    break
                task = self._task_queue.popleft()
                assignment = TaskAssignment(
                    # Reuse the ID handed back by submit_task so callers can
                    # correlate results with their submission.
                    task_id=task.task_id,
                    worker_id=agent_id,
                    payload=task.payload,
                    deadline=datetime.utcnow() + timedelta(seconds=task.timeout),
                )
                if not await self._send_task_assignment(assignment):
                    # Requeue the original request (not the assignment) so the
                    # queue keeps holding objects with .payload/.timeout.
                    self._task_queue.append(task)
                    break

    async def _send_task_assignment(self, assignment: TaskAssignment) -> bool:
        """Send ``assignment`` to its target worker.

        Returns:
            True if the message was sent, False if a network error occurred.
            Requeueing is left to the caller, which still holds the original
            task request.
        """
        try:
            await self._send_message(
                receiver=assignment.worker_id,
                payload_type="TaskAssignment",
                payload=assignment.dict(),
            )
        # NOTE(review): AgentNetworkError is not imported in the visible part
        # of this module -- confirm where it is defined and import it,
        # otherwise this clause raises NameError when an exception propagates.
        except AgentNetworkError as e:
            self._logger.error(f"Failed to assign task {assignment.task_id}: {e}")
            return False
        return True

    async def _scale_swarm(self) -> None:
        """Deploy extra workers when the backlog exceeds 5 tasks per worker."""
        pending_tasks = len(self._task_queue)
        current_workers = len(WorkerAgent.get_worker_metrics())

        if pending_tasks > current_workers * 5:  # Scale-up threshold
            scale_count = min(
                (pending_tasks // 5) - current_workers,
                self.config.swarm_size_limit - current_workers,
            )
            # range() of a non-positive count is empty, so no extra guard needed.
            for _ in range(scale_count):
                worker_id = f"worker-{uuid.uuid4().hex[:8]}"
                await self._deploy_new_worker(worker_id)

    def _start_in_background(self, agent: WorkerAgent) -> None:
        """Run ``agent.start()`` as a task and hold a reference until it ends."""
        task = asyncio.create_task(agent.start())
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    async def _deploy_new_worker(self, worker_id: AgentID) -> None:
        """Orchestrate new worker deployment (Kubernetes integration example)"""
        # TODO: Implement actual deployment logic
        worker = WorkerAgent(worker_id)
        self._registry.register(worker)
        self._start_in_background(worker)

    async def _recover_failed_agents(self) -> None:
        """Restart every agent whose failure count reached the threshold."""
        # Iterate over a snapshot: _restart_agent awaits, and message handlers
        # may insert new keys into _failure_counts in the meantime.
        for agent_id, count in list(self._failure_counts.items()):
            if count >= self.config.failure_threshold:
                await self._restart_agent(agent_id)
                self._failure_counts[agent_id] = 0

    async def _restart_agent(self, agent_id: AgentID) -> None:
        """Shut down the registered agent (if any) and start a replacement."""
        self._logger.warning(f"Restarting agent {agent_id}")
        old_agent = self._registry.get(agent_id)
        if old_agent:
            await old_agent.shutdown()
        # The agent may never have reported metrics, so tolerate a missing key.
        self._swarm_state.pop(agent_id, None)

        new_agent = WorkerAgent(agent_id)
        self._registry.register(new_agent)
        self._start_in_background(new_agent)

    def _generate_health_report(self) -> SwarmHealthReport:
        """Build a health report from worker metrics and supervisor state.

        Averages are 0 when no worker reports metrics. An agent counts as dead
        when it has swarm state but no entry in the worker metrics.
        """
        worker_metrics = WorkerAgent.get_worker_metrics()
        worker_count = len(worker_metrics)
        if worker_count:
            avg_cpu = sum(w.cpu_usage for w in worker_metrics.values()) / worker_count
            avg_mem = sum(w.mem_usage_gb for w in worker_metrics.values()) / worker_count
        else:
            avg_cpu = avg_mem = 0.0
        return SwarmHealthReport(
            total_agents=len(self._swarm_state),
            active_workers=worker_count,
            avg_cpu_util=avg_cpu,
            avg_mem_util=avg_mem,
            pending_tasks=len(self._task_queue),
            dead_agents=[
                aid for aid in self._swarm_state
                if aid not in worker_metrics
            ],
        )

    async def submit_task(self, task: TaskRequest) -> str:
        """Public API for task submission; returns the task's ID.

        The queue is bounded, so submitting to a full queue evicts the oldest
        pending task. That eviction is logged rather than happening silently.
        """
        if len(self._task_queue) == self._task_queue.maxlen:
            self._logger.warning(
                "Task queue is full; the oldest pending task will be dropped"
            )
        self._task_queue.append(task)
        return task.task_id

    async def shutdown_swarm(self) -> None:
        """Shut down every registered worker, then the supervisor itself."""
        for agent in self._registry.get_all_agents():
            if isinstance(agent, WorkerAgent):
                await agent.shutdown()
        await super().shutdown()

    @classmethod
    def get_global_health(cls) -> SwarmHealthReport:
        """Return the health report of the first registered supervisor.

        Raises:
            ValueError: if no supervisor is registered.
        """
        # Uses the same registry accessor as shutdown_swarm.
        supervisor = next(
            (
                agent for agent in cls._registry.get_all_agents()
                if isinstance(agent, SupervisorAgent)
            ),
            None,
        )
        if supervisor is None:
            raise ValueError("No active supervisor")
        return supervisor._generate_health_report()
251+
252+
# Kubernetes-enhanced Supervisor
253+
class K8sSupervisor(SupervisorAgent):
    """Supervisor with Kubernetes cluster integration"""

    async def _deploy_new_worker(self, worker_id: AgentID) -> None:
        """Create a single-replica Deployment for a new worker.

        The pod template is labelled ``agent-id=<worker_id>`` and the
        Deployment selects on that label. The selector is a required field of
        the Deployment spec, and ``_restart_agent`` finds pods by the same
        label.
        """
        from kubernetes import client, config  # Requires k8s SDK

        # Load cluster config
        config.load_incluster_config()
        api = client.AppsV1Api()

        pod_labels = {"agent-id": str(worker_id)}

        # Create new worker deployment
        deployment = client.V1Deployment(
            metadata=client.V1ObjectMeta(name=f"aelion-worker-{worker_id}"),
            spec=client.V1DeploymentSpec(
                replicas=1,
                selector=client.V1LabelSelector(match_labels=pod_labels),
                template=client.V1PodTemplateSpec(
                    metadata=client.V1ObjectMeta(labels=pod_labels),
                    spec=client.V1PodSpec(
                        containers=[
                            client.V1Container(
                                name="worker",
                                image="aelionai/worker:latest",
                                env=[
                                    client.V1EnvVar(
                                        name="AGENT_ID",
                                        value=str(worker_id),
                                    )
                                ],
                            )
                        ]
                    ),
                ),
            ),
        )

        # The client call is a blocking HTTP request; run it in the default
        # executor so the supervisor's event loop keeps running.
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(
            None,
            lambda: api.create_namespaced_deployment(
                namespace="aelion",
                body=deployment,
            ),
        )
        self._logger.info(f"Deployed worker {worker_id} via Kubernetes")

    async def _restart_agent(self, agent_id: AgentID) -> None:
        """Delete the pods labelled with ``agent_id`` in the ``aelion`` namespace.

        Pods owned by a Deployment are recreated by its controller after
        deletion, which is what performs the restart. A warning is logged
        when no pod matches.
        """
        from kubernetes import client, config

        config.load_incluster_config()
        core_api = client.CoreV1Api()
        loop = asyncio.get_running_loop()

        pods = await loop.run_in_executor(
            None,
            lambda: core_api.list_namespaced_pod(
                namespace="aelion",
                label_selector=f"agent-id={agent_id}",
            ),
        )

        if not pods.items:
            self._logger.warning(f"No K8s pod found for agent {agent_id}")
            return

        for pod in pods.items:
            # Bind the name as a default so each lambda sees its own pod.
            await loop.run_in_executor(
                None,
                lambda name=pod.metadata.name: core_api.delete_namespaced_pod(
                    name=name,
                    namespace="aelion",
                ),
            )
        self._logger.info(f"Restarted K8s pod for agent {agent_id}")

0 commit comments

Comments
 (0)