"""API endpoints, SSE, polling."""

import asyncio
import hashlib
import json
import logging
from asyncio import Queue
from pathlib import Path

from fastapi import Depends, HTTPException, Query, WebSocket, WebSocketDisconnect
from starlette.responses import StreamingResponse
import httpx

from .auth import AUTH_ENABLED, ZITADEL_ISSUER, _validate_token, require_auth
from .runtime import Runtime, TRACE_FILE

log = logging.getLogger("runtime")

# Active runtime reference (set by WS endpoint)
_active_runtime: Runtime | None = None

# SSE subscribers: one bounded queue per connected /api/events client
_sse_subscribers: list[Queue] = []

# Async message pipeline state (started by /api/send, read by /api/result)
_pipeline_task: asyncio.Task | None = None
_pipeline_result: dict = {"status": "idle"}
_pipeline_id: int = 0

# HUD event types that, while a pipeline is running, update its reported stage.
_PROGRESS_EVENTS = frozenset(
    {"thinking", "perceived", "decided", "streaming", "tool_exec", "interpreted", "updated"}
)


def _broadcast_sse(event: dict) -> None:
    """Push an event to all SSE subscribers and update pipeline progress.

    Delivery is best-effort: ``put_nowait`` never blocks, and a subscriber
    whose queue is full simply misses this event.

    While ``_pipeline_result`` has status ``"running"``, events that carry a
    ``node`` and one of ``_PROGRESS_EVENTS`` are mirrored into it as
    ``stage``/``event`` so ``/api/result`` can report progress.
    """
    for q in _sse_subscribers:
        try:
            q.put_nowait(event)
        except asyncio.QueueFull:
            pass  # slow consumer: drop this event rather than stall the others
    if _pipeline_result.get("status") == "running":
        node = event.get("node", "")
        evt = event.get("event", "")
        if node and evt in _PROGRESS_EVENTS:
            _pipeline_result["stage"] = node
            _pipeline_result["event"] = evt


def _state_hash() -> str:
    """Return a short fingerprint of the active session's state.

    Covers the memorizer state and the history length; ``/api/poll`` compares
    it with the client's previous value to detect changes.

    Returns:
        The first 12 hex chars of the digest, or ``"no_session"`` when no
        WebSocket session is active.
    """
    if not _active_runtime:
        return "no_session"
    raw = json.dumps({
        "mem": _active_runtime.memorizer.state,
        "hlen": len(_active_runtime.history),
    }, sort_keys=True)
    # MD5 is used purely as a change detector, not for security; declaring that
    # keeps it usable on FIPS-restricted OpenSSL builds.
    return hashlib.md5(raw.encode(), usedforsecurity=False).hexdigest()[:12]


async def _fetch_userinfo(access_token: str) -> dict | None:
    """Fetch the OIDC userinfo document for ``access_token`` (best-effort).

    Returns:
        The decoded JSON object on HTTP 200. Returns ``None`` for a non-200
        status, a body that is not a JSON object, or a transport error. The
        last two cases are logged. A flaky identity provider therefore cannot
        abort the WebSocket handshake.
    """
    try:
        async with httpx.AsyncClient() as client:
            resp = await client.get(
                f"{ZITADEL_ISSUER}/oidc/v1/userinfo",
                headers={"Authorization": f"Bearer {access_token}"},
            )
        if resp.status_code != 200:
            return None
        info = resp.json()
    except (httpx.HTTPError, ValueError) as e:  # ValueError covers JSONDecodeError
        log.warning(f"[auth] userinfo enrichment failed: {e}")
        return None
    return info if isinstance(info, dict) else None


def register_routes(app):
    """Register all API routes on the FastAPI app.

    Routes share module-level state: the ``/ws`` endpoint sets
    ``_active_runtime``, which the HTTP endpoints read. ``/api/send`` drives
    ``_pipeline_task``/``_pipeline_result``.
    """

    @app.get("/health")
    async def health():
        return {"status": "ok"}

    @app.get("/auth/config")
    async def auth_config():
        """Expose the public auth settings the frontend needs to start a login."""
        # Imported at call time so the current module attribute values are returned.
        from .auth import ZITADEL_ISSUER, ZITADEL_CLIENT_ID, ZITADEL_PROJECT_ID, AUTH_ENABLED
        return {
            "enabled": AUTH_ENABLED,
            "issuer": ZITADEL_ISSUER,
            "clientId": ZITADEL_CLIENT_ID,
            "projectId": ZITADEL_PROJECT_ID,
        }

    @app.websocket("/ws")
    async def ws_endpoint(ws: WebSocket, token: str | None = Query(None), access_token: str | None = Query(None)):
        """Main session socket: authenticate, create a Runtime, and dispatch messages.

        Each incoming frame must be a JSON object. ``type == "action"`` goes to
        ``handle_action``, ``type == "cancel_process"`` cancels a process by
        ``pid``, and anything else is treated as a chat message.
        The newest connection becomes ``_active_runtime``.
        """
        global _active_runtime
        user_claims = {"sub": "anonymous"}
        # NOTE(review): with AUTH_ENABLED but no ``token`` query param, the
        # connection is still accepted as anonymous -- confirm this is intended.
        if AUTH_ENABLED and token:
            try:
                user_claims = await _validate_token(token)
            except HTTPException:
                await ws.close(code=4001, reason="Invalid token")
                return
            if not user_claims.get("name") and access_token:
                info = await _fetch_userinfo(access_token)
                if info is not None:
                    log.info(f"[auth] userinfo enrichment: {info}")
                    user_claims["name"] = info.get("name")
                    user_claims["preferred_username"] = info.get("preferred_username")
                    user_claims["email"] = info.get("email")

        origin = ws.headers.get("origin", ws.headers.get("host", ""))
        await ws.accept()
        runtime = Runtime(ws, user_claims=user_claims, origin=origin, broadcast=_broadcast_sse)
        _active_runtime = runtime
        try:
            while True:
                data = await ws.receive_text()
                try:
                    msg = json.loads(data)
                except json.JSONDecodeError:
                    msg = None
                if not isinstance(msg, dict):
                    # One malformed frame should not tear down the whole session.
                    log.warning("[ws] ignoring malformed message (expected a JSON object)")
                    continue
                msg_type = msg.get("type")
                if msg_type == "action":
                    await runtime.handle_action(msg.get("action", "unknown"), msg.get("data"))
                elif msg_type == "cancel_process":
                    runtime.process_manager.cancel(msg.get("pid", 0))
                else:
                    await runtime.handle_message(msg.get("text", ""), dashboard=msg.get("dashboard"))
        except WebSocketDisconnect:
            pass  # normal client disconnect; cleanup happens below
        finally:
            # Runs on every exit path so an unexpected error cannot leave the
            # sensor running or a dead runtime registered as active.
            runtime.sensor.stop()
            if _active_runtime is runtime:
                _active_runtime = None

    @app.get("/api/events")
    async def sse_events(user=Depends(require_auth)):
        """Stream broadcast events to the client as Server-Sent Events."""
        q: Queue = Queue(maxsize=100)
        _sse_subscribers.append(q)

        async def generate():
            try:
                while True:
                    event = await q.get()
                    yield f"data: {json.dumps(event)}\n\n"
            finally:
                # Runs on client disconnect/cancellation (which is allowed to
                # propagate) so the broadcaster stops feeding a dead queue.
                if q in _sse_subscribers:
                    _sse_subscribers.remove(q)

        return StreamingResponse(generate(), media_type="text/event-stream",
                                 headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})

    @app.get("/api/poll")
    async def poll(since: str = "", user=Depends(require_auth)):
        """Return a state snapshot, or just ``changed: False`` if ``since`` matches the current hash."""
        h = _state_hash()
        if since and since == h:
            return {"changed": False, "hash": h}
        rt = _active_runtime
        return {
            "changed": True,
            "hash": h,
            "state": rt.memorizer.state if rt else None,
            "history_len": len(rt.history) if rt else 0,
            "last_messages": rt.history[-3:] if rt else [],
        }

    @app.post("/api/send/check")
    async def api_send_check(user=Depends(require_auth)):
        """Validate runtime is ready to accept a message. Fast, no LLM calls."""
        rt = _active_runtime
        if not rt:
            return {"ready": False, "reason": "no_session",
                    "detail": "No WS connection -- someone must be connected via browser first"}
        if _pipeline_task and not _pipeline_task.done():
            return {"ready": False, "reason": "busy", "detail": "Pipeline already running"}
        return {
            "ready": True,
            "graph": rt.graph.get("name", "unknown"),
            "identity": rt.identity,
            "history_len": len(rt.history),
        }

    @app.post("/api/send")
    async def api_send(body: dict, user=Depends(require_auth)):
        """Queue a message for async processing. Returns immediately with a message ID.

        Only one pipeline runs at a time (409 otherwise). Progress and the final
        outcome are published via ``_pipeline_result`` (see ``/api/result``).

        Raises:
            HTTPException: 409 if there is no session or a pipeline is running;
                400 if ``text`` is missing, empty, or not a string.
        """
        global _pipeline_task, _pipeline_result, _pipeline_id
        # Pin the runtime now: the WS may disconnect (or a new one connect)
        # while the pipeline runs, and results must come from this session.
        runtime = _active_runtime
        if not runtime:
            raise HTTPException(status_code=409, detail="No active session -- someone must be connected via WS first")
        if _pipeline_task and not _pipeline_task.done():
            raise HTTPException(status_code=409, detail="Pipeline already running")
        raw_text = body.get("text") or ""
        if not isinstance(raw_text, str):
            raise HTTPException(status_code=400, detail="'text' must be a string")
        text = raw_text.strip()
        if not text:
            raise HTTPException(status_code=400, detail="Missing 'text' field")

        _pipeline_id += 1
        msg_id = f"msg_{_pipeline_id}"
        dashboard = body.get("dashboard")
        _pipeline_result = {"status": "running", "id": msg_id, "stage": "queued", "text": text}

        async def _run_pipeline():
            global _pipeline_result
            try:
                _pipeline_result["stage"] = "input"
                await runtime.handle_message(text, dashboard=dashboard)
                _pipeline_result = {
                    "status": "done",
                    "id": msg_id,
                    "stage": "done",
                    "response": runtime.history[-1]["content"] if runtime.history else "",
                    "memorizer": runtime.memorizer.state,
                }
            except Exception as e:  # task boundary: record the failure for /api/result
                log.exception(f"[api] pipeline error: {e}")
                _pipeline_result = {
                    "status": "error",
                    "id": msg_id,
                    "stage": "error",
                    "detail": str(e),
                }

        # Keeping the task in a module global also keeps a strong reference to it.
        _pipeline_task = asyncio.create_task(_run_pipeline())
        return {"status": "queued", "id": msg_id}

    @app.get("/api/result")
    async def api_result(user=Depends(require_auth)):
        """Poll for the current pipeline result."""
        return _pipeline_result

    @app.post("/api/clear")
    async def api_clear(user=Depends(require_auth)):
        """Clear the active session's history and UI node state."""
        rt = _active_runtime
        if not rt:
            raise HTTPException(status_code=409, detail="No active session")
        rt.history.clear()
        rt.ui_node.state.clear()
        rt.ui_node.bindings.clear()
        rt.ui_node.current_controls.clear()
        return {"status": "cleared"}

    @app.get("/api/state")
    async def get_state(user=Depends(require_auth)):
        """Return the memorizer state and history length of the active session."""
        rt = _active_runtime
        if not rt:
            return {"status": "no_session"}
        return {
            "status": "active",
            "memorizer": rt.memorizer.state,
            "history_len": len(rt.history),
        }

    @app.get("/api/history")
    async def get_history(last: int = 10, user=Depends(require_auth)):
        """Return the last ``last`` history entries (none when ``last <= 0``)."""
        rt = _active_runtime
        if not rt:
            return {"status": "no_session", "messages": []}
        # ``xs[-0:]`` would return the whole list, so non-positive counts are explicit.
        messages = rt.history[-last:] if last > 0 else []
        return {
            "status": "active",
            "messages": messages,
        }

    @app.get("/api/graph/active")
    async def get_active_graph():
        """Describe the graph currently selected for new sessions (unauthenticated)."""
        from .engine import load_graph, get_graph_for_cytoscape
        from .runtime import _active_graph_name
        graph = load_graph(_active_graph_name)
        return {
            "name": graph["name"],
            "description": graph["description"],
            "nodes": graph["nodes"],
            "edges": graph["edges"],
            "cytoscape": get_graph_for_cytoscape(graph),
        }

    @app.get("/api/graph/list")
    async def get_graph_list():
        """List the available graphs (unauthenticated)."""
        from .engine import list_graphs
        return {"graphs": list_graphs()}

    @app.post("/api/graph/switch")
    async def switch_graph(body: dict, user=Depends(require_auth)):
        """Select the graph used by sessions created after this call."""
        from .engine import load_graph
        # Same module object as the ``.runtime`` imports used elsewhere in this file.
        from . import runtime as rt
        name = body.get("name", "")
        graph = load_graph(name)  # validates it exists
        rt._active_graph_name = name
        return {"status": "ok", "name": graph["name"],
                "note": "New sessions will use this graph. Existing session unchanged."}

    # --- Test status (real-time) ---
    _test_status = {"running": False, "current": "", "results": [], "last_green": None, "last_red": None}

    @app.post("/api/test/status")
    async def post_test_status(body: dict, user=Depends(require_auth)):
        """Receive test status updates from the test runner."""
        event = body.get("event", "")
        if event == "suite_start":
            _test_status["running"] = True
            _test_status["current"] = body.get("suite", "")
            _test_status["results"] = []
        elif event == "step_result":
            result = body.get("result", {})
            _test_status["results"].append(result)
            _test_status["current"] = f"{result.get('step', '')} — {result.get('check', '')}"
            if result.get("status") == "FAIL":
                _test_status["last_red"] = result
            elif result.get("status") == "PASS":
                _test_status["last_green"] = result
        elif event == "suite_end":
            _test_status["running"] = False
            _test_status["current"] = ""

        # Broadcast to frontend via SSE + WS
        payload = {"type": "test_status", **_test_status}
        _broadcast_sse(payload)
        rt = _active_runtime
        if rt:
            try:
                await rt.ws.send_text(json.dumps(payload))
            except Exception as e:  # best-effort push; the SSE broadcast already went out
                log.debug(f"[api] test status WS push failed: {e}")
        return {"ok": True}

    @app.get("/api/test/status")
    async def get_test_status(user=Depends(require_auth)):
        return _test_status

    @app.get("/api/tests")
    async def get_tests():
        """Latest test results from runtime_test.py."""
        results_path = Path(__file__).parent.parent / "testcases" / "results.json"
        if not results_path.exists():
            return {}
        return json.loads(results_path.read_text(encoding="utf-8"))

    @app.get("/api/trace")
    async def get_trace(last: int = 30, user=Depends(require_auth)):
        """Return the last ``last`` JSON lines of the trace file; unparseable lines are skipped."""
        if not TRACE_FILE.exists() or last <= 0:
            return {"lines": []}
        lines = TRACE_FILE.read_text(encoding="utf-8").strip().split("\n")
        parsed = []
        for line in lines[-last:]:
            try:
                parsed.append(json.loads(line))
            except json.JSONDecodeError:
                pass  # partial or corrupt trace line
        return {"lines": parsed}