- Harness reports to /api/test/status with suite_start/step_result/suite_end - Frontend shows x/44 progress, per-test duration, total elapsed time - Auto-discovers test count from test modules (no hardcoded number) - run_all.py --report URL pushes live results to browser - Fix: suite_start with count only resets on first call Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
340 lines
13 KiB
Python
340 lines
13 KiB
Python
"""API endpoints, SSE, polling."""
|
|
|
|
import asyncio
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
from asyncio import Queue
|
|
from pathlib import Path
|
|
|
|
from fastapi import Depends, HTTPException, Query, WebSocket, WebSocketDisconnect
|
|
from starlette.responses import StreamingResponse
|
|
|
|
import httpx
|
|
|
|
from .auth import AUTH_ENABLED, ZITADEL_ISSUER, _validate_token, require_auth
|
|
from .runtime import Runtime, TRACE_FILE
|
|
|
|
log = logging.getLogger("runtime")

# ---------------------------------------------------------------------------
# Shared module state (read and written by the route handlers)
# ---------------------------------------------------------------------------

# The single persistent Runtime; created lazily on first use and dropped
# (set back to None) when the graph is switched.
_active_runtime: Runtime | None = None

# One bounded queue per open /api/events stream; _broadcast_sse fans out here.
_sse_subscribers: list[Queue] = []

# Background message pipeline driven by /api/send: the task currently running
# (if any), its latest status dict as served by /api/result, and a counter
# used to mint message ids.
_pipeline_task: asyncio.Task | None = None
_pipeline_result: dict = {"status": "idle"}
_pipeline_id: int = 0
|
|
|
|
|
|
def _broadcast_sse(event: dict):
    """Fan ``event`` out to every SSE subscriber and mirror pipeline progress.

    Delivery is best-effort: a subscriber whose queue is full simply misses
    this event. While the pipeline status is ``"running"``, node-level HUD
    events (those with a non-empty ``node`` and one of the tracked event
    names) also record the active node/event in ``_pipeline_result``.
    """
    for subscriber in _sse_subscribers:
        try:
            subscriber.put_nowait(event)
        except asyncio.QueueFull:
            continue  # slow consumer: drop rather than block the broadcaster

    if _pipeline_result.get("status") != "running":
        return

    node_name = event.get("node", "")
    event_name = event.get("event", "")
    tracked = ("thinking", "perceived", "decided", "streaming", "tool_exec", "interpreted", "updated")
    if node_name and event_name in tracked:
        _pipeline_result.update(stage=node_name, event=event_name)
|
|
|
|
|
|
def _state_hash() -> str:
    """Return a short fingerprint of the active runtime's observable state.

    The fingerprint covers the memorizer state and the history length, and is
    the literal ``"no_session"`` when no runtime is active. MD5 serves only as
    a cheap change detector here, not for anything security-related.
    """
    runtime = _active_runtime
    if not runtime:
        return "no_session"
    snapshot = {"mem": runtime.memorizer.state, "hlen": len(runtime.history)}
    digest = hashlib.md5(json.dumps(snapshot, sort_keys=True).encode())
    return digest.hexdigest()[:12]
|
|
|
|
|
|
def register_routes(app):
    """Register all API routes on the FastAPI app.

    Every handler below is a closure attached to ``app`` through its
    ``@app.get`` / ``@app.post`` / ``@app.websocket`` decorator. Routes that
    declare ``user=Depends(require_auth)`` go through the auth dependency;
    ``/health``, ``/auth/config``, ``/ws`` (which checks its own query token),
    ``/api/graph/active``, ``/api/graph/list`` and ``/api/tests`` do not.

    Shared state lives in the module globals (``_active_runtime``,
    ``_pipeline_task`` / ``_pipeline_result`` / ``_pipeline_id``,
    ``_sse_subscribers``) plus the ``_test_status`` dict local to this call.
    """

    @app.get("/health")
    async def health():
        """Liveness probe; always returns ``{"status": "ok"}``."""
        return {"status": "ok"}

    @app.get("/auth/config")
    async def auth_config():
        """Return the public OIDC settings the frontend needs to start a login."""
        from .auth import ZITADEL_ISSUER, ZITADEL_CLIENT_ID, ZITADEL_PROJECT_ID, AUTH_ENABLED
        return {
            "enabled": AUTH_ENABLED,
            "issuer": ZITADEL_ISSUER,
            "clientId": ZITADEL_CLIENT_ID,
            "projectId": ZITADEL_PROJECT_ID,
        }

    def _ensure_runtime(user_claims=None, origin=""):
        """Get or create the persistent runtime.

        ``user_claims`` / ``origin`` are only used when a new ``Runtime`` is
        constructed; an existing runtime is returned unchanged.
        """
        global _active_runtime
        if _active_runtime is None:
            _active_runtime = Runtime(user_claims=user_claims, origin=origin,
                                      broadcast=_broadcast_sse)
            log.info("[api] created persistent runtime")
        return _active_runtime

    @app.websocket("/ws")
    async def ws_endpoint(ws: WebSocket, token: str | None = Query(None),
                          access_token: str | None = Query(None)):
        """Bidirectional channel between a browser and the persistent runtime.

        When auth is enabled and ``token`` is given it is validated with
        ``_validate_token``; an invalid token closes the socket with code 4001.
        If the validated claims have no ``name`` and ``access_token`` is given,
        name/username/email are filled in from the OIDC userinfo endpoint.

        Each incoming frame must be a JSON object, dispatched on ``type``:
        ``"action"`` -> ``runtime.handle_action``, ``"cancel_process"`` ->
        ``runtime.process_manager.cancel``, anything else ->
        ``runtime.handle_message``. Malformed frames are logged and skipped.
        """
        user_claims = {"sub": "anonymous"}
        # NOTE(review): with AUTH_ENABLED but no ``token`` the socket is still
        # accepted as "anonymous", unlike HTTP routes guarded by require_auth.
        # Confirm this is intentional.
        if AUTH_ENABLED and token:
            try:
                user_claims = await _validate_token(token)
            except HTTPException:
                await ws.close(code=4001, reason="Invalid token")
                return
            if not user_claims.get("name") and access_token:
                # Enrichment is cosmetic: on network/JSON failure keep the
                # already-validated claims instead of dropping the connection.
                try:
                    async with httpx.AsyncClient() as client:
                        resp = await client.get(f"{ZITADEL_ISSUER}/oidc/v1/userinfo",
                                                headers={"Authorization": f"Bearer {access_token}"})
                    if resp.status_code == 200:
                        info = resp.json()
                        log.info(f"[auth] userinfo enrichment: {info}")
                        user_claims["name"] = info.get("name")
                        user_claims["preferred_username"] = info.get("preferred_username")
                        user_claims["email"] = info.get("email")
                except (httpx.HTTPError, ValueError) as exc:
                    log.warning("[auth] userinfo enrichment failed: %s", exc)
        origin = ws.headers.get("origin", ws.headers.get("host", ""))
        await ws.accept()

        # Get or create runtime, attach WS
        runtime = _ensure_runtime(user_claims=user_claims, origin=origin)
        runtime.update_identity(user_claims, origin)
        runtime.attach_ws(ws)

        try:
            while True:
                data = await ws.receive_text()
                try:
                    msg = json.loads(data)
                except json.JSONDecodeError:
                    log.warning("[api] ignoring malformed WS frame (invalid JSON)")
                    continue
                if not isinstance(msg, dict):
                    log.warning("[api] ignoring malformed WS frame (not an object)")
                    continue
                msg_type = msg.get("type")
                if msg_type == "action":
                    await runtime.handle_action(msg.get("action", "unknown"), msg.get("data"))
                elif msg_type == "cancel_process":
                    runtime.process_manager.cancel(msg.get("pid", 0))
                else:
                    await runtime.handle_message(msg.get("text", ""), dashboard=msg.get("dashboard"))
        except WebSocketDisconnect:
            log.info("[api] WS disconnected — runtime stays alive")
        finally:
            # Detach on every exit path (disconnect or handler error) so the
            # runtime never keeps sending to a dead socket.
            runtime.detach_ws()

    @app.get("/api/events")
    async def sse_events(user=Depends(require_auth)):
        """Server-Sent Events stream of everything passed to ``_broadcast_sse``."""

        async def generate():
            # Register inside the generator so registration and removal are
            # always paired: if streaming never starts, nothing is leaked.
            q: Queue = Queue(maxsize=100)
            _sse_subscribers.append(q)
            try:
                while True:
                    event = await q.get()
                    yield f"data: {json.dumps(event)}\n\n"
            finally:
                # Runs on client disconnect (cancellation propagates normally).
                _sse_subscribers.remove(q)

        return StreamingResponse(generate(), media_type="text/event-stream",
                                 headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})

    @app.get("/api/poll")
    async def poll(since: str = "", user=Depends(require_auth)):
        """Cheap change polling: full payload only when the state hash moved."""
        h = _state_hash()
        if since and since == h:
            return {"changed": False, "hash": h}
        return {
            "changed": True,
            "hash": h,
            "state": _active_runtime.memorizer.state if _active_runtime else {},
            "history_len": len(_active_runtime.history) if _active_runtime else 0,
            "last_messages": _active_runtime.history[-3:] if _active_runtime else [],
        }

    @app.post("/api/send/check")
    async def api_send_check(user=Depends(require_auth)):
        """Validate runtime is ready to accept a message. Fast, no LLM calls."""
        runtime = _ensure_runtime()
        if _pipeline_task and not _pipeline_task.done():
            return {"ready": False, "reason": "busy", "detail": "Pipeline already running"}
        return {
            "ready": True,
            "graph": runtime.graph.get("name", "unknown"),
            "identity": runtime.identity,
            "history_len": len(runtime.history),
            "ws_connected": runtime.sink.ws is not None,
        }

    @app.post("/api/send")
    async def api_send(body: dict, user=Depends(require_auth)):
        """Queue a message for async processing. Returns immediately with a message ID.

        Only one pipeline runs at a time. Progress and the final outcome are
        published in ``_pipeline_result`` (served by ``/api/result``).

        Raises:
            HTTPException(409): a pipeline is already running.
            HTTPException(400): ``text`` is missing, blank, or not a string.
        """
        global _pipeline_task, _pipeline_result, _pipeline_id
        runtime = _ensure_runtime()
        if _pipeline_task and not _pipeline_task.done():
            raise HTTPException(status_code=409, detail="Pipeline already running")
        raw_text = body.get("text", "")
        if not isinstance(raw_text, str):
            # Reject explicitly instead of failing on .strip() with a 500.
            raise HTTPException(status_code=400, detail="'text' must be a string")
        text = raw_text.strip()
        if not text:
            raise HTTPException(status_code=400, detail="Missing 'text' field")

        _pipeline_id += 1
        msg_id = f"msg_{_pipeline_id}"
        dashboard = body.get("dashboard")

        _pipeline_result = {"status": "running", "id": msg_id, "stage": "queued", "text": text}

        async def _run_pipeline():
            global _pipeline_result
            try:
                _pipeline_result["stage"] = "input"
                await runtime.handle_message(text, dashboard=dashboard)
                _pipeline_result = {
                    "status": "done",
                    "id": msg_id,
                    "stage": "done",
                    "response": runtime.history[-1]["content"] if runtime.history else "",
                    "memorizer": runtime.memorizer.state,
                }
            except asyncio.CancelledError:
                # Don't leave /api/result reporting "running" forever.
                _pipeline_result = {
                    "status": "error",
                    "id": msg_id,
                    "stage": "error",
                    "detail": "cancelled",
                }
                raise
            except Exception as e:
                log.exception("[api] pipeline error: %s", e)
                _pipeline_result = {
                    "status": "error",
                    "id": msg_id,
                    "stage": "error",
                    "detail": str(e),
                }

        _pipeline_task = asyncio.create_task(_run_pipeline())
        return {"status": "queued", "id": msg_id}

    @app.get("/api/result")
    async def api_result(user=Depends(require_auth)):
        """Poll for the current pipeline result."""
        return _pipeline_result

    @app.post("/api/clear")
    async def api_clear(user=Depends(require_auth)):
        """Clear conversation history and the UI node's state, bindings, controls and machines.

        The memorizer state is not touched here.
        """
        runtime = _ensure_runtime()
        runtime.history.clear()
        runtime.ui_node.state.clear()
        runtime.ui_node.bindings.clear()
        runtime.ui_node.current_controls.clear()
        runtime.ui_node.machines.clear()
        return {"status": "cleared"}

    @app.get("/api/state")
    async def get_state(user=Depends(require_auth)):
        """Snapshot of memorizer state, history length and WS connectivity."""
        runtime = _ensure_runtime()
        return {
            "status": "active",
            "memorizer": runtime.memorizer.state,
            "history_len": len(runtime.history),
            "ws_connected": runtime.sink.ws is not None,
        }

    @app.get("/api/history")
    async def get_history(last: int = 10, user=Depends(require_auth)):
        """Return the ``last`` most recent history messages (none if ``last <= 0``)."""
        runtime = _ensure_runtime()
        # history[-0:] is the whole list, so non-positive values must be
        # handled explicitly rather than sliced.
        messages = runtime.history[-last:] if last > 0 else []
        return {
            "status": "active",
            "messages": messages,
        }

    @app.get("/api/graph/active")
    async def get_active_graph():
        """Describe the currently selected graph, including a Cytoscape rendering."""
        from .engine import load_graph, get_graph_for_cytoscape
        # Imported at call time so the current value (set by switch_graph) is read.
        from .runtime import _active_graph_name
        graph = load_graph(_active_graph_name)
        return {
            "name": graph["name"],
            "description": graph["description"],
            "nodes": graph["nodes"],
            "edges": graph["edges"],
            "cytoscape": get_graph_for_cytoscape(graph),
        }

    @app.get("/api/graph/list")
    async def get_graph_list():
        """List the available graphs."""
        from .engine import list_graphs
        return {"graphs": list_graphs()}

    @app.post("/api/graph/switch")
    async def switch_graph(body: dict, user=Depends(require_auth)):
        """Select another graph and drop the current runtime so the next request rebuilds it."""
        global _active_runtime
        from .engine import load_graph
        # Relative import, consistent with the rest of this module (the
        # package is not assumed to be importable as ``agent``).
        from . import runtime as rt
        name = body.get("name", "")
        graph = load_graph(name)  # validates it exists
        rt._active_graph_name = name
        # Kill old runtime, next request creates new one with new graph.
        # NOTE(review): an already-connected /ws handler keeps its own
        # reference to the old runtime; confirm that is the intended behavior.
        if _active_runtime:
            _active_runtime.sensor.stop()
        _active_runtime = None
        return {"status": "ok", "name": graph["name"],
                "note": "New sessions will use this graph. Existing session unchanged."}

    # --- Test status (real-time) ---
    _test_status = {
        "running": False,
        "current": "",
        "results": [],
        "last_green": None,
        "last_red": None,
        "total_expected": 0,
    }

    @app.post("/api/test/status")
    async def post_test_status(body: dict, user=Depends(require_auth)):
        """Receive test status updates from the test runner.

        Events: ``suite_start`` (optionally with ``count``), ``step_result``
        (with a ``result`` object), ``suite_end``. The updated status is pushed
        to SSE subscribers and, if a runtime with a connected WS exists, to
        that socket.

        Raises:
            HTTPException(400): ``result`` of a ``step_result`` is not an object.
        """
        event = body.get("event", "")
        if event == "suite_start":
            _test_status["running"] = True
            _test_status["current"] = body.get("suite", "")
            if body.get("count"):
                # Any suite_start carrying a non-zero count begins a new run
                # and resets results; a suite_start without one only updates
                # "current".
                _test_status["results"] = []
                _test_status["total_expected"] = body["count"]
                _test_status["last_green"] = None
                _test_status["last_red"] = None
        elif event == "step_result":
            result = body.get("result") or {}
            if not isinstance(result, dict):
                raise HTTPException(status_code=400, detail="'result' must be an object")
            _test_status["results"].append(result)
            _test_status["current"] = f"{result.get('step', '')} — {result.get('check', '')}"
            if result.get("status") == "FAIL":
                _test_status["last_red"] = result
            elif result.get("status") == "PASS":
                _test_status["last_green"] = result
        elif event == "suite_end":
            _test_status["running"] = False
            _test_status["current"] = ""

        # Snapshot the results list: SSE events are serialized later, and the
        # live list keeps growing with subsequent step_results.
        payload = {"type": "test_status", **_test_status, "results": list(_test_status["results"])}
        _broadcast_sse(payload)
        # Only push over WS when a runtime already exists; building a whole
        # Runtime just to report test progress would be a needless side effect
        # (a fresh runtime has no socket attached anyway).
        runtime = _active_runtime
        if runtime is not None and runtime.sink.ws:
            try:
                await runtime.sink.ws.send_text(json.dumps(payload))
            except Exception as exc:
                # Best-effort push; the SSE channel above already has the update.
                log.debug("[api] test_status WS push failed: %s", exc)
        return {"ok": True}

    @app.get("/api/test/status")
    async def get_test_status(user=Depends(require_auth)):
        """Return the current in-memory test run status."""
        return _test_status

    @app.get("/api/tests")
    async def get_tests():
        """Latest test results from runtime_test.py."""
        results_path = Path(__file__).parent.parent / "testcases" / "results.json"
        if not results_path.exists():
            return {}
        return json.loads(results_path.read_text(encoding="utf-8"))

    @app.get("/api/trace")
    async def get_trace(last: int = 30, user=Depends(require_auth)):
        """Return up to ``last`` parsed JSON lines from the end of the trace file."""
        # lines[-0:] would be the whole file, so non-positive values mean "none".
        if last <= 0 or not TRACE_FILE.exists():
            return {"lines": []}
        lines = TRACE_FILE.read_text(encoding="utf-8").strip().split("\n")
        parsed = []
        for line in lines[-last:]:
            try:
                parsed.append(json.loads(line))
            except json.JSONDecodeError:
                # Skip partial/corrupt lines (e.g. a write in progress).
                continue
        return {"lines": parsed}
|