Architecture:
- Graph engine (engine.py) loads graph definitions and instantiates nodes
- Versioned nodes: input_v1, thinker_v1, output_v1, memorizer_v1, director_v1
- NODE_REGISTRY for dynamic node lookup by name
- Graph API: /api/graph/active, /api/graph/list, /api/graph/switch
- Graph definition: graphs/v1_current.py (7 nodes, 13 edges, 3 edge types)

S3* Audit system:
- Workspace mismatch detection (server vs browser controls)
- Code-without-tools retry (Thinker wrote code but made no tool calls)
- Intent-without-action retry (request intent, but Thinker only produced text)
- Dashboard feedback: browser sends workspace state on every message
- Sensor: continuous comparison on a 5s tick

State machines:
- create_machine / add_state / reset_machine / destroy_machine via function calling
- Local transitions (go:) resolve without an LLM round-trip
- Button persistence across turns

Database tools:
- query_db tool via pymysql to MariaDB K3s pod (eras2_production)
- Table rendering in workspace (tab-separated parsing)
- Director pre-planning with Opus for complex data requests
- Error retry with corrected SQL

Frontend:
- Cytoscape.js pipeline graph with real-time node animations
- Overlay scrollbars (CSS-only, no reflow)
- Tool call/result trace events
- S3* audit events in trace

Testing:
- 167 integration tests (11 test suites)
- 22 node-level unit tests (test_nodes/)
- Three test levels: node unit, graph integration, scenario

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
226 lines
8.3 KiB
Python
226 lines
8.3 KiB
Python
"""API endpoints, SSE, polling."""
|
|
|
|
import asyncio
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
from asyncio import Queue
|
|
from pathlib import Path
|
|
|
|
from fastapi import Depends, HTTPException, Query, WebSocket, WebSocketDisconnect
|
|
from starlette.responses import StreamingResponse
|
|
|
|
import httpx
|
|
|
|
from .auth import AUTH_ENABLED, ZITADEL_ISSUER, _validate_token, require_auth
|
|
from .runtime import Runtime, TRACE_FILE
|
|
|
|
log = logging.getLogger("runtime")

# Open Server-Sent-Events streams: every /api/events request registers one
# bounded Queue here, and _broadcast_sse() pushes each event into all of them.
_sse_subscribers: list[Queue] = list()

# Runtime belonging to the most recent /ws client. The REST endpoints read and
# drive this session; the /ws handler assigns it on connect and resets it to
# None when that same client goes away.
_active_runtime: Runtime | None = None
|
|
|
|
|
|
def _broadcast_sse(event: dict):
    """Deliver *event* to every connected /api/events client.

    Never blocks and never raises for a slow client: a subscriber whose
    bounded queue is already full is skipped, so that client misses this
    one event while everyone else still receives it.
    """
    for subscriber in _sse_subscribers:
        if subscriber.full():
            # Backed-up consumer -- drop the event for this subscriber only.
            continue
        subscriber.put_nowait(event)
|
|
|
|
|
|
def _state_hash() -> str:
    """Return a short fingerprint of the active session's observable state.

    The fingerprint covers the memorizer state and the number of history
    entries; /api/poll compares it with the client's previous value to decide
    whether anything changed. Returns ``"no_session"`` when no WebSocket
    client is connected.

    MD5 serves only as a cheap change-detection checksum, never for security,
    so it is created with ``usedforsecurity=False``; without that flag the
    algorithm may be blocked on FIPS-restricted Python/OpenSSL builds.
    """
    if not _active_runtime:
        return "no_session"
    raw = json.dumps({
        "mem": _active_runtime.memorizer.state,
        "hlen": len(_active_runtime.history),
    }, sort_keys=True)
    return hashlib.md5(raw.encode(), usedforsecurity=False).hexdigest()[:12]
|
|
|
|
|
|
def register_routes(app):
    """Register all HTTP, WebSocket and SSE routes on the FastAPI *app*.

    Session model: the most recently connected ``/ws`` client owns the
    module-level ``_active_runtime``; the REST endpoints below read or drive
    that runtime and answer 409 / ``"no_session"`` when nobody is connected.
    """

    @app.get("/health")
    async def health():
        """Liveness probe."""
        return {"status": "ok"}

    @app.get("/auth/config")
    async def auth_config():
        """Public OIDC settings the browser needs to start a login."""
        # Imported inside the handler, so the auth module's attributes are
        # read on every request rather than once at module import.
        from .auth import ZITADEL_ISSUER, ZITADEL_CLIENT_ID, ZITADEL_PROJECT_ID, AUTH_ENABLED
        return {
            "enabled": AUTH_ENABLED,
            "issuer": ZITADEL_ISSUER,
            "clientId": ZITADEL_CLIENT_ID,
            "projectId": ZITADEL_PROJECT_ID,
        }

    @app.websocket("/ws")
    async def ws_endpoint(ws: WebSocket, token: str | None = Query(None),
                          access_token: str | None = Query(None)):
        """Main chat socket; the connection becomes the active session.

        Query params:
            token: token checked with ``_validate_token`` when auth is enabled.
            access_token: optional bearer token used to fill in ``name``,
                ``preferred_username`` and ``email`` from the OIDC userinfo
                endpoint when the validated claims carry no ``name``.

        Frames are JSON objects dispatched on ``type``: ``"action"`` ->
        ``runtime.handle_action``, ``"cancel_process"`` ->
        ``runtime.process_manager.cancel``, anything else -> chat message via
        ``runtime.handle_message``. Malformed frames are logged and skipped.
        """
        global _active_runtime
        user_claims = {"sub": "anonymous"}
        if AUTH_ENABLED:
            rejection = None
            if not token:
                # Otherwise an unauthenticated socket would get full control
                # of the session that the REST API guards with require_auth.
                rejection = "Missing token"
            else:
                try:
                    user_claims = await _validate_token(token)
                except HTTPException:
                    rejection = "Invalid token"
            if rejection is not None:
                # Per ASGI, closing before accept() is turned into a bare
                # HTTP 403 and the 4001 code/reason never reach the client.
                await ws.accept()
                await ws.close(code=4001, reason=rejection)
                return
            if not user_claims.get("name") and access_token:
                info = await _fetch_userinfo(access_token)
                if info is not None:
                    log.info(f"[auth] userinfo enrichment: {info}")
                    user_claims["name"] = info.get("name")
                    user_claims["preferred_username"] = info.get("preferred_username")
                    user_claims["email"] = info.get("email")

        origin = ws.headers.get("origin", ws.headers.get("host", ""))
        await ws.accept()
        runtime = Runtime(ws, user_claims=user_claims, origin=origin, broadcast=_broadcast_sse)
        _active_runtime = runtime
        try:
            while True:
                data = await ws.receive_text()
                try:
                    msg = json.loads(data)
                except json.JSONDecodeError:
                    log.warning("[ws] ignoring malformed JSON frame")
                    continue
                if not isinstance(msg, dict):
                    log.warning("[ws] ignoring non-object JSON frame")
                    continue
                msg_type = msg.get("type")
                if msg_type == "action":
                    await runtime.handle_action(msg.get("action", "unknown"), msg.get("data"))
                elif msg_type == "cancel_process":
                    runtime.process_manager.cancel(msg.get("pid", 0))
                else:
                    await runtime.handle_message(msg.get("text", ""), dashboard=msg.get("dashboard"))
        except WebSocketDisconnect:
            pass  # normal client hang-up
        finally:
            # Runs on disconnect AND on any handler error, so the sensor is
            # always stopped and REST calls stop targeting a dead socket.
            runtime.sensor.stop()
            if _active_runtime is runtime:
                _active_runtime = None

    @app.get("/api/events")
    async def sse_events(user=Depends(require_auth)):
        """Server-Sent-Events stream of everything passed to _broadcast_sse."""
        q: Queue = Queue(maxsize=100)

        async def generate():
            # Register only once streaming actually starts, so a response that
            # is never iterated cannot leave an orphaned queue behind.
            _sse_subscribers.append(q)
            try:
                while True:
                    event = await q.get()
                    yield f"data: {json.dumps(event)}\n\n"
            finally:
                # No `except CancelledError`: cancellation must propagate.
                _sse_subscribers.remove(q)

        return StreamingResponse(generate(), media_type="text/event-stream",
                                 headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})

    @app.get("/api/poll")
    async def poll(since: str = "", user=Depends(require_auth)):
        """Polling alternative to /api/events.

        Pass the ``hash`` from the previous response as *since*; when it still
        matches, only ``{"changed": False, "hash": ...}`` is returned.
        """
        h = _state_hash()
        if since and since == h:
            return {"changed": False, "hash": h}
        runtime = _active_runtime
        return {
            "changed": True,
            "hash": h,
            "state": runtime.memorizer.state if runtime else None,
            "history_len": len(runtime.history) if runtime else 0,
            "last_messages": runtime.history[-3:] if runtime else [],
        }

    @app.post("/api/send")
    async def api_send(body: dict, user=Depends(require_auth)):
        """Inject a chat message into the active session and return the reply.

        Raises 409 without an active session and 400 when ``text`` is missing,
        blank, or not a string.
        """
        # Hold our own reference: the WS may disconnect (or be replaced) while
        # handle_message is awaited, changing the global underneath us.
        runtime = _active_runtime
        if not runtime:
            raise HTTPException(status_code=409, detail="No active session -- someone must be connected via WS first")
        text = body.get("text", "")
        if not isinstance(text, str) or not text.strip():
            raise HTTPException(status_code=400, detail="Missing 'text' field")
        await runtime.handle_message(text.strip(), dashboard=body.get("dashboard"))
        return {
            "status": "ok",
            "response": runtime.history[-1]["content"] if runtime.history else "",
            "memorizer": runtime.memorizer.state,
        }

    @app.post("/api/clear")
    async def api_clear(user=Depends(require_auth)):
        """Wipe chat history and UI node state/bindings/controls of the session."""
        runtime = _active_runtime
        if not runtime:
            raise HTTPException(status_code=409, detail="No active session")
        runtime.history.clear()
        runtime.ui_node.state.clear()
        runtime.ui_node.bindings.clear()
        runtime.ui_node.current_controls.clear()
        return {"status": "cleared"}

    @app.get("/api/state")
    async def get_state(user=Depends(require_auth)):
        """Memorizer state and history length of the active session."""
        runtime = _active_runtime
        if not runtime:
            return {"status": "no_session"}
        return {
            "status": "active",
            "memorizer": runtime.memorizer.state,
            "history_len": len(runtime.history),
        }

    @app.get("/api/history")
    async def get_history(last: int = 10, user=Depends(require_auth)):
        """Newest *last* chat-history entries (none when last <= 0)."""
        runtime = _active_runtime
        if not runtime:
            return {"status": "no_session", "messages": []}
        return {
            "status": "active",
            "messages": _tail(runtime.history, last),
        }

    @app.get("/api/graph/active")
    async def get_active_graph():
        """Definition of the currently selected graph plus Cytoscape elements."""
        from .engine import load_graph, get_graph_for_cytoscape
        from .runtime import _active_graph_name
        graph = load_graph(_active_graph_name)
        return {
            "name": graph["name"],
            "description": graph["description"],
            "nodes": graph["nodes"],
            "edges": graph["edges"],
            "cytoscape": get_graph_for_cytoscape(graph),
        }

    @app.get("/api/graph/list")
    async def get_graph_list():
        """Graphs available to switch to."""
        from .engine import list_graphs
        return {"graphs": list_graphs()}

    @app.post("/api/graph/switch")
    async def switch_graph(body: dict, user=Depends(require_auth)):
        """Select the graph used by sessions created from now on."""
        from .engine import load_graph
        # Relative import: the same runtime module the rest of this file uses,
        # instead of a hard-coded top-level package name.
        from . import runtime as rt
        name = body.get("name", "")
        graph = load_graph(name)  # validates it exists
        rt._active_graph_name = name
        return {"status": "ok", "name": graph["name"],
                "note": "New sessions will use this graph. Existing session unchanged."}

    @app.get("/api/tests")
    async def get_tests():
        """Latest test results from runtime_test.py."""
        results_path = Path(__file__).parent.parent / "testcases" / "results.json"
        if not results_path.exists():
            return {}
        return json.loads(results_path.read_text(encoding="utf-8"))

    @app.get("/api/trace")
    async def get_trace(last: int = 30, user=Depends(require_auth)):
        """Last *last* parseable JSON records of the trace file (none when last <= 0).

        Blank lines are ignored, and lines that fail to parse (e.g. a record
        still being written) are skipped.
        """
        from collections import deque
        if last <= 0:
            return {"lines": []}
        try:
            with TRACE_FILE.open(encoding="utf-8") as fh:
                # Bounded deque keeps only the tail instead of the whole file.
                tail = deque((ln for ln in fh if ln.strip()), maxlen=last)
        except FileNotFoundError:
            return {"lines": []}
        parsed = []
        for line in tail:
            try:
                parsed.append(json.loads(line))
            except json.JSONDecodeError:
                continue
        return {"lines": parsed}


def _tail(items, n: int) -> list:
    """Return the last *n* entries of the sliceable *items*; [] when n <= 0.

    Plain ``items[-n:]`` would return the WHOLE sequence for n == 0 and drop
    a prefix (rather than keep a suffix) for negative n.
    """
    return items[-n:] if n > 0 else []


async def _fetch_userinfo(access_token: str) -> dict | None:
    """Best-effort fetch of the OIDC userinfo document for *access_token*.

    Returns the decoded JSON on HTTP 200, otherwise None. Transport errors and
    non-JSON bodies are logged and treated as "no extra info", so a flaky
    userinfo endpoint never rejects a connection whose token already passed
    validation.
    """
    try:
        async with httpx.AsyncClient() as client:
            resp = await client.get(f"{ZITADEL_ISSUER}/oidc/v1/userinfo",
                                    headers={"Authorization": f"Bearer {access_token}"})
            if resp.status_code != 200:
                return None
            return resp.json()
    except (httpx.HTTPError, ValueError) as exc:
        log.warning("[auth] userinfo enrichment failed: %s", exc)
        return None
|