Nico 4c412d3c4b v0.14.4: Interpreter wired in v2, tool_call convention, Haiku models, UI fix
- Wire Interpreter into v2 pipeline (after Thinker tool_output, before Output)
- Rename tool_exec -> tool_call everywhere (consistent convention across v1/v2)
- Switch Director v1+v2 to anthropic/claude-haiku-4.5 (was opus, reserved)
- Fix UI apply_machine_ops crash when states are strings instead of dicts
- Fix runtime_test.py async poll to match on message ID (prevent stale results)
- Add traceback to pipeline error logging

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-29 06:06:13 +02:00

341 lines
13 KiB
Python

"""API endpoints, SSE, polling."""
import asyncio
import hashlib
import json
import logging
from asyncio import Queue
from pathlib import Path
from fastapi import Depends, HTTPException, Query, WebSocket, WebSocketDisconnect
from starlette.responses import StreamingResponse
import httpx
from .auth import AUTH_ENABLED, ZITADEL_ISSUER, _validate_token, require_auth
from .runtime import Runtime, TRACE_FILE
log = logging.getLogger("runtime")

# --- Module-level state shared by the route handlers ---
# The single Runtime instance: None until a handler creates it, and reset to
# None when the active graph is switched.
_active_runtime: Runtime | None = None

# One bounded queue per connected /api/events (SSE) client.
_sse_subscribers: list[Queue] = list()

# Background /api/send pipeline: the task currently running (if any), the
# latest status snapshot served by /api/result, and the counter used to mint
# "msg_<n>" message IDs.
_pipeline_task: asyncio.Task | None = None
_pipeline_result: dict = dict(status="idle")
_pipeline_id: int = 0
def _broadcast_sse(event: dict):
    """Fan ``event`` out to every SSE subscriber and track pipeline progress.

    A subscriber whose queue is full simply misses this event. While the
    async pipeline is running, an event that names a node and carries one of
    the progress event types updates the pipeline's ``stage``/``event`` fields.
    """
    for subscriber in _sse_subscribers:
        try:
            subscriber.put_nowait(event)
        except asyncio.QueueFull:
            # Queue is full (subscriber not draining): drop rather than block.
            pass

    if _pipeline_result.get("status") != "running":
        return

    progress_events = ("thinking", "perceived", "decided", "streaming",
                       "tool_call", "interpreted", "updated")
    node_name = event.get("node", "")
    event_type = event.get("event", "")
    if node_name and event_type in progress_events:
        _pipeline_result["stage"] = node_name
        _pipeline_result["event"] = event_type
def _state_hash() -> str:
    """Return a short fingerprint of the active runtime's pollable state.

    The fingerprint covers the memorizer state and the history *length*, so
    history edits that keep the same length are not detected. Returns
    ``"no_session"`` when no runtime exists yet.
    """
    if not _active_runtime:
        return "no_session"
    raw = json.dumps({
        "mem": _active_runtime.memorizer.state,
        "hlen": len(_active_runtime.history),
    }, sort_keys=True)
    # MD5 is only a change-detection checksum here, not a security primitive.
    # Declaring that keeps it usable on FIPS-restricted OpenSSL builds, where
    # a plain hashlib.md5() call is rejected.
    return hashlib.md5(raw.encode(), usedforsecurity=False).hexdigest()[:12]
def register_routes(app):
    """Register all API routes on the FastAPI app.

    The handlers are closures over ``app``, the ``_ensure_runtime`` helper and
    the in-memory ``_test_status`` dict defined below; they read and rebind
    the module-level runtime / SSE / pipeline globals.
    """

    @app.get("/health")
    async def health():
        """Liveness check; always reports ok."""
        return {"status": "ok"}

    @app.get("/auth/config")
    async def auth_config():
        """Return the auth settings for the frontend (no authentication required)."""
        # Local import: reads the auth module's values at call time.
        from .auth import ZITADEL_ISSUER, ZITADEL_CLIENT_ID, ZITADEL_PROJECT_ID, AUTH_ENABLED
        return {
            "enabled": AUTH_ENABLED,
            "issuer": ZITADEL_ISSUER,
            "clientId": ZITADEL_CLIENT_ID,
            "projectId": ZITADEL_PROJECT_ID,
        }

    def _ensure_runtime(user_claims=None, origin=""):
        """Get or create the persistent runtime.

        ``user_claims``/``origin`` are only used when a new Runtime has to be
        created; an existing runtime is returned unchanged.
        """
        global _active_runtime
        if _active_runtime is None:
            _active_runtime = Runtime(user_claims=user_claims, origin=origin,
                                      broadcast=_broadcast_sse)
            log.info("[api] created persistent runtime")
        return _active_runtime

    @app.websocket("/ws")
    async def ws_endpoint(ws: WebSocket, token: str | None = Query(None),
                          access_token: str | None = Query(None)):
        """Bidirectional channel between one client and the shared runtime.

        With auth enabled and a ``token`` given, the token is validated (close
        code 4001 on failure). If the claims carry no name and an
        ``access_token`` is given, the ZITADEL userinfo endpoint is queried to
        fill in name / preferred_username / email. Incoming frames are JSON
        objects dispatched on ``type``: ``action``, ``cancel_process`` or,
        for anything else, a chat message.
        """
        user_claims = {"sub": "anonymous"}
        # NOTE(review): with AUTH_ENABLED but no token, the socket is still
        # accepted as "anonymous", unlike the REST routes that depend on
        # require_auth — confirm this is intended.
        if AUTH_ENABLED and token:
            try:
                user_claims = await _validate_token(token)
            except HTTPException:
                await ws.close(code=4001, reason="Invalid token")
                return
            if not user_claims.get("name") and access_token:
                # Best-effort enrichment: a network/HTTP/JSON failure here must
                # not abort an otherwise valid connection.
                info = None
                try:
                    async with httpx.AsyncClient() as client:
                        resp = await client.get(f"{ZITADEL_ISSUER}/oidc/v1/userinfo",
                                                headers={"Authorization": f"Bearer {access_token}"})
                        if resp.status_code == 200:
                            info = resp.json()
                except (httpx.HTTPError, ValueError) as e:
                    log.warning(f"[auth] userinfo enrichment failed: {e}")
                if isinstance(info, dict):
                    # Log which fields arrived, not their (personal) values.
                    log.info(f"[auth] userinfo enrichment: fields={sorted(info)}")
                    user_claims["name"] = info.get("name")
                    user_claims["preferred_username"] = info.get("preferred_username")
                    user_claims["email"] = info.get("email")
        origin = ws.headers.get("origin", ws.headers.get("host", ""))
        await ws.accept()
        # Get or create runtime, attach WS
        runtime = _ensure_runtime(user_claims=user_claims, origin=origin)
        runtime.update_identity(user_claims, origin)
        runtime.attach_ws(ws)
        try:
            while True:
                data = await ws.receive_text()
                try:
                    msg = json.loads(data)
                except json.JSONDecodeError:
                    log.warning("[api] WS: ignoring frame that is not valid JSON")
                    continue
                if not isinstance(msg, dict):
                    log.warning("[api] WS: ignoring frame that is not a JSON object")
                    continue
                msg_type = msg.get("type")
                if msg_type == "action":
                    await runtime.handle_action(msg.get("action", "unknown"), msg.get("data"))
                elif msg_type == "cancel_process":
                    runtime.process_manager.cancel(msg.get("pid", 0))
                else:
                    await runtime.handle_message(msg.get("text", ""), dashboard=msg.get("dashboard"))
        except WebSocketDisconnect:
            log.info("[api] WS disconnected — runtime stays alive")
        finally:
            # Detach on every exit path (not only a clean disconnect) so the
            # runtime is not left holding this closed socket.
            # NOTE(review): if another client attached after this one, this may
            # detach the newer socket — depends on Runtime.detach_ws; confirm.
            runtime.detach_ws()

    @app.get("/api/events")
    async def sse_events(user=Depends(require_auth)):
        """Server-sent event stream of every event passed to _broadcast_sse."""
        async def generate():
            # Subscribe only once streaming actually starts. A generator that is
            # never iterated never runs its ``finally``, so a queue registered
            # up front could leak and keep filling for a client that is gone.
            q: Queue = Queue(maxsize=100)
            _sse_subscribers.append(q)
            try:
                while True:
                    event = await q.get()
                    yield f"data: {json.dumps(event)}\n\n"
            finally:
                # Runs on client disconnect/cancellation; cancellation itself
                # is allowed to propagate.
                _sse_subscribers.remove(q)
        return StreamingResponse(generate(), media_type="text/event-stream",
                                 headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})

    @app.get("/api/poll")
    async def poll(since: str = "", user=Depends(require_auth)):
        """Return a state snapshot, or just ``changed: False`` if ``since`` matches."""
        h = _state_hash()
        if since and since == h:
            return {"changed": False, "hash": h}
        return {
            "changed": True,
            "hash": h,
            "state": _active_runtime.memorizer.state if _active_runtime else {},
            "history_len": len(_active_runtime.history) if _active_runtime else 0,
            "last_messages": _active_runtime.history[-3:] if _active_runtime else [],
        }

    @app.post("/api/send/check")
    async def api_send_check(user=Depends(require_auth)):
        """Validate runtime is ready to accept a message. Fast, no LLM calls."""
        runtime = _ensure_runtime()
        if _pipeline_task and not _pipeline_task.done():
            return {"ready": False, "reason": "busy", "detail": "Pipeline already running"}
        return {
            "ready": True,
            "graph": runtime.graph.get("name", "unknown"),
            "identity": runtime.identity,
            "history_len": len(runtime.history),
            "ws_connected": runtime.sink.ws is not None,
        }

    @app.post("/api/send")
    async def api_send(body: dict, user=Depends(require_auth)):
        """Queue a message for async processing. Returns immediately with a message ID.

        Raises:
            HTTPException 409: a pipeline run is still in progress.
            HTTPException 400: ``text`` is missing/empty or not a string.
        """
        global _pipeline_task, _pipeline_result, _pipeline_id
        runtime = _ensure_runtime()
        if _pipeline_task and not _pipeline_task.done():
            raise HTTPException(status_code=409, detail="Pipeline already running")
        # ``or ""`` also covers an explicit JSON null.
        text = body.get("text") or ""
        if not isinstance(text, str):
            raise HTTPException(status_code=400, detail="'text' must be a string")
        text = text.strip()
        if not text:
            raise HTTPException(status_code=400, detail="Missing 'text' field")
        _pipeline_id += 1
        msg_id = f"msg_{_pipeline_id}"
        dashboard = body.get("dashboard")
        _pipeline_result = {"status": "running", "id": msg_id, "stage": "queued", "text": text}

        async def _run_pipeline():
            # Runs in the background; _broadcast_sse updates "stage"/"event"
            # while the status is "running".
            global _pipeline_result
            try:
                _pipeline_result["stage"] = "input"
                await runtime.handle_message(text, dashboard=dashboard)
                _pipeline_result = {
                    "status": "done",
                    "id": msg_id,
                    "stage": "done",
                    "response": runtime.history[-1]["content"] if runtime.history else "",
                    "memorizer": runtime.memorizer.state,
                }
            except Exception as e:
                # log.exception appends the active traceback.
                log.exception(f"[api] pipeline error: {e}")
                _pipeline_result = {
                    "status": "error",
                    "id": msg_id,
                    "stage": "error",
                    "detail": str(e),
                }

        # Keeping the task in a module global also holds a strong reference to it.
        _pipeline_task = asyncio.create_task(_run_pipeline())
        return {"status": "queued", "id": msg_id}

    @app.get("/api/result")
    async def api_result(user=Depends(require_auth)):
        """Poll for the current pipeline result."""
        return _pipeline_result

    @app.post("/api/clear")
    async def api_clear(user=Depends(require_auth)):
        """Clear the history and the UI node's state, bindings, controls and machines.

        The memorizer state is not touched here.
        """
        runtime = _ensure_runtime()
        runtime.history.clear()
        runtime.ui_node.state.clear()
        runtime.ui_node.bindings.clear()
        runtime.ui_node.current_controls.clear()
        runtime.ui_node.machines.clear()
        return {"status": "cleared"}

    @app.get("/api/state")
    async def get_state(user=Depends(require_auth)):
        """Return memorizer state, history length and WS connection flag."""
        runtime = _ensure_runtime()
        return {
            "status": "active",
            "memorizer": runtime.memorizer.state,
            "history_len": len(runtime.history),
            "ws_connected": runtime.sink.ws is not None,
        }

    @app.get("/api/history")
    async def get_history(last: int = 10, user=Depends(require_auth)):
        """Return the ``last`` most recent history messages (none if ``last`` <= 0)."""
        runtime = _ensure_runtime()
        # history[-0:] is the whole list, so non-positive values need a guard.
        return {
            "status": "active",
            "messages": runtime.history[-last:] if last > 0 else [],
        }

    @app.get("/api/graph/active")
    async def get_active_graph():
        """Describe the currently selected graph (no authentication required)."""
        from .engine import load_graph, get_graph_for_cytoscape
        from .runtime import _active_graph_name
        graph = load_graph(_active_graph_name)
        return {
            "name": graph["name"],
            "description": graph["description"],
            "nodes": graph["nodes"],
            "edges": graph["edges"],
            "cytoscape": get_graph_for_cytoscape(graph),
        }

    @app.get("/api/graph/list")
    async def get_graph_list():
        """List the available graphs (no authentication required)."""
        from .engine import list_graphs
        return {"graphs": list_graphs()}

    @app.post("/api/graph/switch")
    async def switch_graph(body: dict, user=Depends(require_auth)):
        """Select graph ``name`` for new runtimes and drop the current runtime.

        NOTE(review): a pipeline task that is already running keeps using the
        dropped runtime and is not cancelled — confirm this is acceptable.
        """
        global _active_runtime
        from .engine import load_graph
        import agent.runtime as rt
        name = body.get("name", "")
        graph = load_graph(name)  # validates it exists
        rt._active_graph_name = name
        # Kill old runtime, next request creates new one with new graph
        if _active_runtime:
            _active_runtime.sensor.stop()
            _active_runtime = None
        return {"status": "ok", "name": graph["name"],
                "note": "New sessions will use this graph. Existing session unchanged."}

    # --- Test status (real-time) ---
    _test_status = {"running": False, "current": "", "results": [], "last_green": None, "last_red": None, "total_expected": 0}

    @app.post("/api/test/status")
    async def post_test_status(body: dict, user=Depends(require_auth)):
        """Receive test status updates from the test runner.

        Handles ``suite_start`` (a non-zero ``count`` resets accumulated
        results), ``step_result`` and ``suite_end``, then pushes the updated
        status to SSE subscribers and, if attached, the runtime's WebSocket.
        """
        event = body.get("event", "")
        if event == "suite_start":
            _test_status["running"] = True
            _test_status["current"] = body.get("suite", "")
            if body.get("count"):
                # First suite_start with count resets everything
                _test_status["results"] = []
                _test_status["total_expected"] = body["count"]
                _test_status["last_green"] = None
                _test_status["last_red"] = None
        elif event == "step_result":
            # ``or {}`` also covers an explicit JSON null.
            result = body.get("result") or {}
            _test_status["results"].append(result)
            _test_status["current"] = f"{result.get('step', '')}{result.get('check', '')}"
            if result.get("status") == "FAIL":
                _test_status["last_red"] = result
            elif result.get("status") == "PASS":
                _test_status["last_green"] = result
        elif event == "suite_end":
            _test_status["running"] = False
            _test_status["current"] = ""
        # Broadcast to frontend via SSE + WS
        payload = {"type": "test_status", **_test_status}
        _broadcast_sse(payload)
        runtime = _ensure_runtime()
        if runtime.sink.ws:
            try:
                await runtime.sink.ws.send_text(json.dumps(payload))
            except Exception as e:
                # Best-effort push; the SSE broadcast above already went out.
                log.debug(f"[api] test_status WS push failed: {e}")
        return {"ok": True}

    @app.get("/api/test/status")
    async def get_test_status(user=Depends(require_auth)):
        """Return the accumulated test status."""
        return _test_status

    @app.get("/api/tests")
    async def get_tests():
        """Latest test results from runtime_test.py."""
        results_path = Path(__file__).parent.parent / "testcases" / "results.json"
        if not results_path.exists():
            return {}
        return json.loads(results_path.read_text(encoding="utf-8"))

    @app.get("/api/trace")
    async def get_trace(last: int = 30, user=Depends(require_auth)):
        """Return the last ``last`` non-blank trace lines that parse as JSON.

        Lines that fail to parse are skipped. Returns no lines when ``last``
        <= 0 or the trace file does not exist.
        """
        from collections import deque
        if last <= 0 or not TRACE_FILE.exists():
            return {"lines": []}
        # Stream the file and keep only the tail in memory instead of reading
        # the whole trace; blank lines do not take up a slot.
        with TRACE_FILE.open(encoding="utf-8") as f:
            tail = deque((line for line in f if line.strip()), maxlen=last)
        parsed = []
        for line in tail:
            try:
                parsed.append(json.loads(line))
            except json.JSONDecodeError:
                pass
        return {"lines": parsed}