- Thinker tool results stream directly to user, skipping Output node (halves latency) - ProcessManager process_start/process_done events render as live cards in chat - UI controls sent before response text, not after - Button clicks route to handle_action(), skip Input, go straight to Thinker - Fix Thinker model: gemini-2.5-flash-preview -> gemini-2.5-flash (old ID expired) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
186 lines
6.7 KiB
Python
186 lines
6.7 KiB
Python
"""API endpoints, SSE, polling."""
|
|
|
|
import asyncio
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
from asyncio import Queue
|
|
from pathlib import Path
|
|
|
|
from fastapi import Depends, HTTPException, Query, WebSocket, WebSocketDisconnect
|
|
from starlette.responses import StreamingResponse
|
|
|
|
import httpx
|
|
|
|
from .auth import AUTH_ENABLED, ZITADEL_ISSUER, _validate_token, require_auth
|
|
from .runtime import Runtime, TRACE_FILE
|
|
|
|
log = logging.getLogger(name="runtime")

# Runtime belonging to the most recently connected /ws client. The WebSocket
# handler installs it after accepting the socket and resets it to None when
# that same runtime's socket disconnects; the REST endpoints read it.
_active_runtime: Runtime | None = None

# One bounded queue per open /api/events stream; _broadcast_sse fans events
# out to every queue in this list.
_sse_subscribers: list[Queue] = list()
|
|
|
|
|
|
def _broadcast_sse(event: dict):
    """Deliver *event* to every queue in ``_sse_subscribers`` without blocking.

    Delivery is best-effort: a subscriber whose queue is already at capacity
    silently misses this event, while every other subscriber still gets it.
    """
    for inbox in _sse_subscribers:
        try:
            inbox.put_nowait(event)
        except asyncio.QueueFull:
            # This listener is backed up; drop the event for it only.
            continue
|
|
|
|
|
|
def _state_hash() -> str:
    """Return a 12-hex-character fingerprint of the active session's state.

    ``/api/poll`` compares this value with the client's ``since`` parameter to
    decide whether anything changed. Returns ``"no_session"`` when no runtime
    is active.

    The fingerprint covers the memorizer state, the history length and the
    newest history entry. The newest entry matters because history can be
    replaced without its length changing (e.g. ``/api/clear`` followed by a
    new exchange), which a length-only hash would report as "unchanged".
    """
    runtime = _active_runtime
    if not runtime:
        return "no_session"
    history = runtime.history
    raw = json.dumps({
        "mem": runtime.memorizer.state,
        "hlen": len(history),
        "tail": history[-1:],
    }, sort_keys=True, default=str)  # default=str: this text only feeds the hash
    # MD5 is a cheap change detector here, not a security primitive; the flag
    # keeps it usable on FIPS-restricted Python builds.
    return hashlib.md5(raw.encode(), usedforsecurity=False).hexdigest()[:12]
|
|
|
|
|
|
def register_routes(app):
    """Register all API routes on the FastAPI app.

    Routes registered here:

    * ``GET /health`` -- liveness check, always ``{"status": "ok"}``.
    * ``GET /auth/config`` -- auth settings from ``.auth`` (no auth dependency).
    * ``WS /ws`` -- interactive session; its Runtime becomes the module-wide
      ``_active_runtime`` until the socket closes.
    * ``GET /api/events`` -- Server-Sent Events stream of everything passed to
      ``_broadcast_sse``.
    * ``GET /api/poll`` -- change detection keyed on ``_state_hash()``.
    * ``POST /api/send`` / ``POST /api/clear`` -- drive the active session.
    * ``GET /api/state`` / ``GET /api/history`` / ``GET /api/trace`` -- inspection.

    Every ``/api/*`` route depends on ``require_auth``.
    """

    async def _enrich_from_userinfo(claims: dict, access_token: str) -> None:
        """Copy name / preferred_username / email from the OIDC userinfo endpoint.

        Best-effort: transport errors and non-JSON replies are logged and the
        already-validated *claims* are left untouched, so an unreachable issuer
        does not abort the WebSocket handshake.
        """
        try:
            async with httpx.AsyncClient() as client:
                resp = await client.get(f"{ZITADEL_ISSUER}/oidc/v1/userinfo",
                                        headers={"Authorization": f"Bearer {access_token}"})
                if resp.status_code != 200:
                    return
                info = resp.json()
        except (httpx.HTTPError, ValueError) as exc:  # ValueError covers bad JSON
            log.warning(f"[auth] userinfo enrichment failed: {exc}")
            return
        # NOTE(review): logs the whole userinfo payload (may include email) at INFO.
        log.info(f"[auth] userinfo enrichment: {info}")
        claims["name"] = info.get("name")
        claims["preferred_username"] = info.get("preferred_username")
        claims["email"] = info.get("email")

    def _read_trace_tail(count: int) -> list:
        """Return the last *count* non-blank lines of TRACE_FILE parsed as JSON.

        Lines that are not valid JSON are skipped; a missing file yields [].
        The file is streamed through a bounded deque, so at most *count* lines
        are held in memory regardless of file size.
        """
        from collections import deque

        try:
            with TRACE_FILE.open(encoding="utf-8", errors="replace") as fh:
                tail = deque((line for line in fh if line.strip()), maxlen=count)
        except FileNotFoundError:
            return []
        parsed = []
        for line in tail:
            try:
                parsed.append(json.loads(line))
            except json.JSONDecodeError:
                continue
        return parsed

    @app.get("/health")
    async def health():
        return {"status": "ok"}

    @app.get("/auth/config")
    async def auth_config():
        from .auth import ZITADEL_ISSUER, ZITADEL_CLIENT_ID, ZITADEL_PROJECT_ID, AUTH_ENABLED
        return {
            "enabled": AUTH_ENABLED,
            "issuer": ZITADEL_ISSUER,
            "clientId": ZITADEL_CLIENT_ID,
            "projectId": ZITADEL_PROJECT_ID,
        }

    @app.websocket("/ws")
    async def ws_endpoint(ws: WebSocket, token: str | None = Query(None),
                          access_token: str | None = Query(None)):
        """Run one interactive session over a WebSocket.

        When ``AUTH_ENABLED`` and a ``token`` is supplied, it is validated with
        ``_validate_token``; an invalid token closes the socket with code 4001.
        ``access_token`` is used only to fetch display claims from userinfo
        when the validated claims lack a ``name``.
        """
        global _active_runtime
        user_claims = {"sub": "anonymous"}
        # NOTE(review): with AUTH_ENABLED but no ``token`` query parameter the
        # client is admitted as anonymous, whereas every /api/* route requires
        # auth -- confirm this is intended.
        if AUTH_ENABLED and token:
            try:
                user_claims = await _validate_token(token)
            except HTTPException:
                await ws.close(code=4001, reason="Invalid token")
                return
            if not user_claims.get("name") and access_token:
                await _enrich_from_userinfo(user_claims, access_token)

        origin = ws.headers.get("origin", ws.headers.get("host", ""))
        await ws.accept()
        runtime = Runtime(ws, user_claims=user_claims, origin=origin, broadcast=_broadcast_sse)
        _active_runtime = runtime
        try:
            while True:
                data = await ws.receive_text()
                try:
                    msg = json.loads(data)
                except json.JSONDecodeError:
                    log.warning("[ws] ignoring non-JSON frame")
                    continue
                if not isinstance(msg, dict):
                    log.warning("[ws] ignoring non-object frame")
                    continue
                kind = msg.get("type")
                if kind == "action":
                    await runtime.handle_action(msg.get("action", "unknown"), msg.get("data"))
                elif kind == "cancel_process":
                    runtime.process_manager.cancel(msg.get("pid", 0))
                else:
                    await runtime.handle_message(msg.get("text", ""))
        except WebSocketDisconnect:
            pass
        finally:
            # Runs on a normal disconnect AND on any unexpected error, so the
            # sensor is always stopped and a dead runtime never stays active.
            runtime.sensor.stop()
            if _active_runtime is runtime:
                _active_runtime = None

    @app.get("/api/events")
    async def sse_events(user=Depends(require_auth)):
        async def generate():
            # Subscribe inside the generator so the queue's lifetime is exactly
            # the try/finally below and it cannot leak if streaming never starts.
            q: Queue = Queue(maxsize=100)
            _sse_subscribers.append(q)
            try:
                while True:
                    event = await q.get()
                    yield f"data: {json.dumps(event)}\n\n"
            finally:
                # CancelledError (client went away) propagates after cleanup
                # instead of being swallowed.
                _sse_subscribers.remove(q)

        return StreamingResponse(generate(), media_type="text/event-stream",
                                 headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})

    @app.get("/api/poll")
    async def poll(since: str = "", user=Depends(require_auth)):
        h = _state_hash()
        if since and since == h:
            return {"changed": False, "hash": h}
        runtime = _active_runtime
        return {
            "changed": True,
            "hash": h,
            "state": runtime.memorizer.state if runtime else None,
            "history_len": len(runtime.history) if runtime else 0,
            "last_messages": runtime.history[-3:] if runtime else [],
        }

    @app.post("/api/send")
    async def api_send(body: dict, user=Depends(require_auth)):
        # Keep a local reference: the WS can disconnect (and reset the global
        # to None) while handle_message is awaited below.
        runtime = _active_runtime
        if not runtime:
            raise HTTPException(status_code=409, detail="No active session -- someone must be connected via WS first")
        text = body.get("text", "")
        if not isinstance(text, str):
            raise HTTPException(status_code=400, detail="'text' must be a string")
        text = text.strip()
        if not text:
            raise HTTPException(status_code=400, detail="Missing 'text' field")
        await runtime.handle_message(text)
        history = runtime.history
        return {
            "status": "ok",
            "response": history[-1]["content"] if history else "",
            "memorizer": runtime.memorizer.state,
        }

    @app.post("/api/clear")
    async def api_clear(user=Depends(require_auth)):
        runtime = _active_runtime
        if not runtime:
            raise HTTPException(status_code=409, detail="No active session")
        runtime.history.clear()
        return {"status": "cleared"}

    @app.get("/api/state")
    async def get_state(user=Depends(require_auth)):
        runtime = _active_runtime
        if not runtime:
            return {"status": "no_session"}
        return {
            "status": "active",
            "memorizer": runtime.memorizer.state,
            "history_len": len(runtime.history),
        }

    @app.get("/api/history")
    async def get_history(last: int = 10, user=Depends(require_auth)):
        runtime = _active_runtime
        if not runtime:
            return {"status": "no_session", "messages": []}
        # history[-0:] is the WHOLE list and a negative value drops leading
        # items, so non-positive counts are handled explicitly.
        messages = runtime.history[-last:] if last > 0 else []
        return {
            "status": "active",
            "messages": messages,
        }

    @app.get("/api/trace")
    async def get_trace(last: int = 30, user=Depends(require_auth)):
        if last <= 0:
            # lines[-0:] would have returned the entire trace file.
            return {"lines": []}
        # Blocking file I/O runs in a worker thread to keep the event loop free.
        return {"lines": await asyncio.to_thread(_read_trace_tail, last)}
|