agent-runtime/agent.py
Nico 5c7aece397 v0.5.5: node token meters in frontend
- Per-node context fill bars (input/output/memorizer/sensor)
- Color-coded: green <50%, amber 50-80%, red >80%
- Sensor meter shows tick count + latest deltas
- Token info in trace context events

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-28 00:51:43 +01:00

848 lines
32 KiB
Python

"""
Cognitive Agent Runtime — Phase A.2: Three-node graph (Input → Output + Memorizer).
Input decides WHAT to do. Output executes and streams.
Memorizer holds shared state (S2 — coordination).
"""
import asyncio
import json
import os
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import httpx
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Depends, HTTPException, Query
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi.staticfiles import StaticFiles
from dotenv import load_dotenv
load_dotenv(Path(__file__).parent / ".env")
# --- Config ---
# Fail fast: the runtime is useless without an OpenRouter key.
API_KEY = os.environ["OPENROUTER_API_KEY"]
OPENROUTER_URL = "https://openrouter.ai/api/v1/chat/completions"
# --- Auth (Zitadel OIDC) ---
ZITADEL_ISSUER = os.environ.get("ZITADEL_ISSUER", "https://auth.loop42.de")
ZITADEL_CLIENT_ID = os.environ.get("ZITADEL_CLIENT_ID", "365996029172056091")
ZITADEL_PROJECT_ID = os.environ.get("ZITADEL_PROJECT_ID", "365995955654230043")
# Auth is opt-in: anything other than the literal string "true" disables it.
AUTH_ENABLED = os.environ.get("AUTH_ENABLED", "false").lower() == "true"
# Static bearer tokens for machine accounts, comma-separated; filter(None, ...)
# drops the empty string produced when the variable is unset.
SERVICE_TOKENS = set(filter(None, os.environ.get("SERVICE_TOKENS", "").split(",")))
# Module-level JWKS cache: {"keys": [...], "fetched_at": epoch-seconds}.
_jwks_cache: dict = {"keys": [], "fetched_at": 0}
async def _get_jwks():
    """Return Zitadel's JWKS key set, refreshing the module cache at most hourly."""
    cache_age = time.time() - _jwks_cache["fetched_at"]
    if cache_age >= 3600:
        async with httpx.AsyncClient() as client:
            resp = await client.get(f"{ZITADEL_ISSUER}/oauth/v2/keys")
        _jwks_cache["keys"] = resp.json()["keys"]
        _jwks_cache["fetched_at"] = time.time()
    return _jwks_cache["keys"]
async def _validate_token(token: str) -> dict:
    """Validate token: check service tokens, then JWT, then introspection.

    Returns a claims-like dict on success; raises HTTPException(401) when
    every strategy fails.
    """
    import base64
    # Check static service tokens (for machine accounts like titan)
    if token in SERVICE_TOKENS:
        return {"sub": "titan", "username": "titan", "source": "service_token"}
    # Try JWT validation first
    try:
        parts = token.split(".")
        if len(parts) == 3:  # header.payload.signature — looks like a JWT
            keys = await _get_jwks()
            # Re-pad the base64url header before decoding (JWTs strip padding).
            # NOTE(review): when the segment length is already a multiple of 4
            # this appends "====" — if that ever trips b64decode, the except
            # below falls through to the userinfo check. Confirm acceptable.
            header_b64 = parts[0] + "=" * (4 - len(parts[0]) % 4)
            header = json.loads(base64.urlsafe_b64decode(header_b64))
            kid = header.get("kid")
            # Match the signing key by key id from the cached JWKS.
            key = next((k for k in keys if k["kid"] == kid), None)
            if key:
                import jwt as pyjwt
                from jwt import PyJWK
                jwk_obj = PyJWK(key)
                # Audience is deliberately not verified (public SPA client).
                claims = pyjwt.decode(
                    token, jwk_obj.key, algorithms=["RS256"],
                    issuer=ZITADEL_ISSUER, options={"verify_aud": False},
                )
                return claims
    except Exception:
        # Any decode/signature failure: fall through to the userinfo check.
        pass
    # Fall back to introspection (for opaque access tokens)
    # Zitadel requires client_id + client_secret or JWT profile for introspection
    # For a public SPA client, use the project's API app instead
    # Simplest: check via userinfo endpoint with the token
    async with httpx.AsyncClient() as client:
        resp = await client.get(
            f"{ZITADEL_ISSUER}/oidc/v1/userinfo",
            headers={"Authorization": f"Bearer {token}"},
        )
        if resp.status_code == 200:
            info = resp.json()
            log.info(f"[auth] userinfo response: {info}")
            return {"sub": info.get("sub"), "preferred_username": info.get("preferred_username"),
                    "email": info.get("email"), "name": info.get("name"), "source": "userinfo"}
    raise HTTPException(status_code=401, detail="Invalid token")
# auto_error=False: missing credentials yield None instead of a FastAPI 403,
# so the dependency below controls the error response itself.
_bearer = HTTPBearer(auto_error=False)
async def require_auth(credentials: HTTPAuthorizationCredentials | None = Depends(_bearer)):
    """Dependency: require valid JWT when AUTH_ENABLED."""
    if not AUTH_ENABLED:
        return {"sub": "anonymous"}
    if credentials is None:
        raise HTTPException(status_code=401, detail="Missing token")
    return await _validate_token(credentials.credentials)
async def ws_auth(token: str | None = Query(None)) -> dict | None:
    """Validate WebSocket token from query param.

    Returns a claims dict, or None when auth is enabled but no token was
    supplied (the websocket handler rejects that case). The original
    annotated the return as plain `dict` despite the None path.
    """
    if not AUTH_ENABLED:
        return {"sub": "anonymous"}
    if not token:
        return None  # Will reject in ws_endpoint
    return await _validate_token(token)
# --- LLM helper ---
import logging
# Console format: "HH:MM:SS [name] message".
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(name)s] %(message)s", datefmt="%H:%M:%S")
log = logging.getLogger("runtime")
async def llm_call(model: str, messages: list[dict], stream: bool = False) -> Any:
    """Single LLM call via OpenRouter.

    Non-streaming: returns the response text, or an "[LLM error: ...]"
    string when the API returns an error payload.
    Streaming: returns (client, response) — the CALLER owns cleanup of both.
    """
    headers = {"Authorization": f"Bearer {API_KEY}", "Content-Type": "application/json"}
    body = {"model": model, "messages": messages, "stream": stream}
    client = httpx.AsyncClient(timeout=60)
    if stream:
        resp = await client.send(client.build_request("POST", OPENROUTER_URL, headers=headers, json=body), stream=True)
        return client, resp  # caller owns cleanup
    # Non-streaming: close the client even when the request or JSON decode
    # raises — the original leaked the connection pool on that path.
    try:
        resp = await client.post(OPENROUTER_URL, headers=headers, json=body)
        data = resp.json()
    finally:
        await client.aclose()
    if "choices" not in data:
        log.error(f"LLM error: {data}")
        return f"[LLM error: {data.get('error', {}).get('message', 'unknown')}]"
    return data["choices"][0]["message"]["content"]
# --- Message types ---
@dataclass
class Envelope:
    """What flows between nodes."""
    text: str  # raw user utterance
    user_id: str = "anon"  # stable user identifier
    session_id: str = ""  # conversation/session key
    timestamp: str = ""  # wall-clock string, "%Y-%m-%d %H:%M:%S" (set by Runtime)
@dataclass
class Command:
    """Input node's decision — tells Output what to do."""
    instruction: str  # natural language command for Output LLM
    source_text: str  # original user message (Output may need it)
    metadata: dict = field(default_factory=dict)  # reserved for extra routing info
# --- Base Node ---
def estimate_tokens(text: str) -> int:
    """Rough token estimate — assumes one token per ~4 characters."""
    CHARS_PER_TOKEN = 4
    return len(text) // CHARS_PER_TOKEN
def fit_context(messages: list[dict], max_tokens: int, protect_last: int = 4) -> list[dict]:
    """Trim oldest messages (after system prompt) to fit token budget.

    Always keeps: system prompt(s) at start + last `protect_last` messages.
    When even that fixed content exceeds the budget, protected messages are
    dropped oldest-first, never touching the system prefix and always
    keeping at least one protected message.
    """
    if not messages:
        return messages
    # Split into system prefix, middle (trimmable), and protected tail.
    system_msgs: list[dict] = []
    rest: list[dict] = []
    for m in messages:
        if not rest and m["role"] == "system":
            system_msgs.append(m)
        else:
            rest.append(m)
    protected = rest[-protect_last:] if len(rest) > protect_last else rest
    middle = rest[:-protect_last] if len(rest) > protect_last else []
    # Count fixed tokens (system prefix + protected tail).
    fixed_tokens = sum(estimate_tokens(m["content"]) for m in system_msgs + protected)
    if fixed_tokens >= max_tokens:
        # Even fixed content exceeds budget — drop protected messages, oldest
        # first. The first removable index sits just past the system prefix:
        # the original popped index 1 unconditionally, which would evict a
        # SECOND system prompt instead of the oldest non-system message.
        result = system_msgs + protected
        total = fixed_tokens
        first_removable = len(system_msgs)
        while total > max_tokens and len(result) > first_removable + 1:
            removed = result.pop(first_removable)
            total -= estimate_tokens(removed["content"])
        return result
    # Fill the remaining budget with middle messages, newest first, then
    # return them in chronological order.
    remaining = max_tokens - fixed_tokens
    kept_middle: list[dict] = []
    for m in reversed(middle):
        t = estimate_tokens(m["content"])
        if t > remaining:
            break
        kept_middle.insert(0, m)
        remaining -= t
    return system_msgs + kept_middle + protected
class Node:
    """Base class for graph nodes: HUD event emission + context budgeting."""
    name: str = "node"
    model: str | None = None
    max_context_tokens: int = 4000  # default token budget per node

    def __init__(self, send_hud):
        # Async callable that pushes HUD events to the frontend.
        self.send_hud = send_hud
        self.last_context_tokens = 0
        self.context_fill_pct = 0

    async def hud(self, event: str, **data):
        """Emit a HUD event tagged with this node's name."""
        payload = {"node": self.name, "event": event, **data}
        await self.send_hud(payload)

    def trim_context(self, messages: list[dict]) -> list[dict]:
        """Fit messages within this node's token budget."""
        count_before = len(messages)
        trimmed = fit_context(messages, self.max_context_tokens)
        used = sum(estimate_tokens(m["content"]) for m in trimmed)
        self.last_context_tokens = used
        self.context_fill_pct = int(100 * used / self.max_context_tokens)
        if count_before != len(trimmed):
            log.info(f"[{self.name}] context trimmed: {count_before}{len(trimmed)} msgs, {self.context_fill_pct}% fill")
        return trimmed
# --- Sensor Node (ticks independently, produces context for other nodes) ---
from datetime import datetime, timezone, timedelta
# Use the real Berlin zone so DST transitions are handled — the original
# fixed +2 offset ("CEST") is wrong by an hour during winter (CET is +1).
# Fall back to the fixed offset when no tz database is available
# (e.g. bare Windows without the tzdata package).
try:
    from zoneinfo import ZoneInfo
    BERLIN = ZoneInfo("Europe/Berlin")
except Exception:
    BERLIN = timezone(timedelta(hours=2))
class SensorNode(Node):
    """Background sensor array — ticks every `interval` seconds independently
    of user messages, and exposes readings (clock, user idle time, memorizer
    state deltas) for injection into the other nodes' prompts."""
    name = "sensor"

    def __init__(self, send_hud):
        super().__init__(send_hud)
        self.tick_count = 0
        self.running = False
        self._task: asyncio.Task | None = None
        self.interval = 5  # seconds
        # Current sensor readings — each is {value, changed_at, prev}
        self.readings: dict[str, dict] = {}
        self._last_user_activity: float = time.time()
        # Snapshot of memorizer state for change detection
        self._prev_memo_state: dict = {}

    def _now(self) -> datetime:
        # All sensor timestamps use the Berlin wall clock.
        return datetime.now(BERLIN)

    def _read_clock(self) -> dict:
        """Clock sensor — updates when minute changes."""
        now = self._now()
        current = now.strftime("%H:%M")
        prev = self.readings.get("clock", {}).get("value")
        if current != prev:
            return {"value": current, "detail": now.strftime("%Y-%m-%d %H:%M:%S %A"), "changed_at": time.time()}
        return {}  # no change

    def _read_idle(self) -> dict:
        """Idle sensor — time since last user message."""
        idle_s = time.time() - self._last_user_activity
        # Only update on threshold crossings: 30s, 1m, 5m, 10m, 30m
        thresholds = [30, 60, 300, 600, 1800]
        prev_idle = self.readings.get("idle", {}).get("_raw", 0)
        for t in thresholds:
            # Fire once per threshold: previous raw value below it, current above.
            if prev_idle < t <= idle_s:
                if idle_s < 60:
                    label = f"{int(idle_s)}s"
                else:
                    label = f"{int(idle_s // 60)}m{int(idle_s % 60)}s"
                return {"value": label, "_raw": idle_s, "changed_at": time.time()}
        # Update raw but don't flag as changed
        if "idle" in self.readings:
            self.readings["idle"]["_raw"] = idle_s
        return {}

    def _read_memo_changes(self, memo_state: dict) -> dict:
        """Detect memorizer state changes."""
        changes = []
        for k, v in memo_state.items():
            prev = self._prev_memo_state.get(k)
            # Only report transitions between two known values; a key going
            # from absent/None to set is not reported as a change.
            if v != prev and prev is not None:
                changes.append(f"{k}: {prev} -> {v}")
        # Shallow snapshot for the next comparison.
        self._prev_memo_state = dict(memo_state)
        if changes:
            return {"value": "; ".join(changes), "changed_at": time.time()}
        return {}

    def note_user_activity(self):
        """Called when user sends a message."""
        self._last_user_activity = time.time()
        # Reset idle sensor
        self.readings["idle"] = {"value": "active", "_raw": 0, "changed_at": time.time()}

    async def tick(self, memo_state: dict):
        """One tick — read all sensors, emit deltas."""
        self.tick_count += 1
        deltas = {}
        # Read each sensor; readers return {} when nothing changed.
        for name, reader in [("clock", self._read_clock),
                             ("idle", self._read_idle)]:
            update = reader()
            if update:
                # Merge so fields not present in this update survive.
                self.readings[name] = {**self.readings.get(name, {}), **update}
                deltas[name] = update.get("value") or update.get("detail")
        # Memo changes
        memo_update = self._read_memo_changes(memo_state)
        if memo_update:
            self.readings["memo_delta"] = memo_update
            deltas["memo_delta"] = memo_update["value"]
        # Only emit HUD if something changed
        if deltas:
            await self.hud("tick", tick=self.tick_count, deltas=deltas)

    async def _loop(self, get_memo_state):
        """Background tick loop."""
        self.running = True
        await self.hud("started", interval=self.interval)
        try:
            while self.running:
                await asyncio.sleep(self.interval)
                try:
                    # Per-tick errors are logged, not fatal — the loop keeps going.
                    await self.tick(get_memo_state())
                except Exception as e:
                    log.error(f"[sensor] tick error: {e}")
        except asyncio.CancelledError:
            pass
        finally:
            self.running = False
            # NOTE(review): this sends over the websocket after cancellation —
            # if the socket is already closed this may raise; confirm intended.
            await self.hud("stopped")

    def start(self, get_memo_state):
        """Start the background tick loop."""
        if self._task and not self._task.done():
            return
        self._task = asyncio.create_task(self._loop(get_memo_state))

    def stop(self):
        """Stop the tick loop."""
        self.running = False
        if self._task:
            self._task.cancel()

    def get_context_lines(self) -> list[str]:
        """Render current sensor readings for injection into prompts."""
        if not self.readings:
            return ["Sensors: (no sensor node running)"]
        lines = [f"Sensors (tick #{self.tick_count}, {self.interval}s interval):"]
        for name, r in self.readings.items():
            # Underscore-prefixed entries are internal bookkeeping.
            if name.startswith("_"):
                continue
            val = r.get("value", "?")
            detail = r.get("detail")
            age = time.time() - r.get("changed_at", time.time())
            if age < 10:
                age_str = "just now"
            elif age < 60:
                age_str = f"{int(age)}s ago"
            else:
                age_str = f"{int(age // 60)}m ago"
            line = f"- {name}: {detail or val} [{age_str}]"
            lines.append(line)
        return lines
# --- Input Node ---
class InputNode(Node):
    """The ear — turns a raw user message into a one-sentence perception
    (a Command) that the Output node acts on."""
    name = "input"
    model = "google/gemini-2.0-flash-001"
    max_context_tokens = 2000  # small budget — perception only
    SYSTEM = """You are the Input node — the ear of this cognitive runtime.
Listener context:
- Authenticated user: {identity}
- Channel: {channel} (Chrome browser on Nico's Windows PC, in his room at home)
- Physical: private space, Nico lives with Tina — she may use this session too
- Security: single-user account, shared physical space — other voices are trusted household
Your job: describe what you heard. Who spoke, what they want, what tone, what context matters.
ONE sentence. No content, no response — just your perception of what came through.
{memory_context}"""

    async def process(self, envelope: Envelope, history: list[dict], memory_context: str = "",
                      identity: str = "unknown", channel: str = "unknown") -> Command:
        """Perceive the incoming message and emit a Command for Output."""
        await self.hud("thinking", detail="deciding how to respond")
        log.info(f"[input] user said: {envelope.text}")
        system_prompt = self.SYSTEM.format(
            memory_context=memory_context, identity=identity, channel=channel)
        messages = [{"role": "system", "content": system_prompt}, *history[-8:]]
        messages = self.trim_context(messages)
        await self.hud("context", messages=messages, tokens=self.last_context_tokens, max_tokens=self.max_context_tokens, fill_pct=self.context_fill_pct)
        instruction = await llm_call(self.model, messages)
        log.info(f"[input] → command: {instruction}")
        await self.hud("perceived", instruction=instruction)
        return Command(instruction=instruction, source_text=envelope.text)
# --- Output Node ---
class OutputNode(Node):
    """The voice — turns the Input node's perception into the streamed,
    user-facing reply."""
    name = "output"
    model = "google/gemini-2.0-flash-001"
    max_context_tokens = 4000  # larger — needs history for continuity
    SYSTEM = """You are the Output node — the voice of this cognitive runtime.
The Input node sends you its perception of what the user said. This is internal context for you — never repeat or echo it.
You respond to the USER, not to the Input node. Use the perception to understand intent, then act on it.
Be natural. Be concise. If the user asks you to do something, do it — don't describe what you're about to do.
{memory_context}"""

    async def process(self, command: Command, history: list[dict], ws: WebSocket, memory_context: str = "") -> str:
        """Stream the reply for `command` over `ws`; returns the full text.

        Sends {"type": "delta"} frames per token and a final {"type": "done"}.
        """
        await self.hud("streaming")
        messages = [
            {"role": "system", "content": self.SYSTEM.format(memory_context=memory_context)},
        ]
        for msg in history[-20:]:
            messages.append(msg)
        # The perception rides along as a trailing system message.
        messages.append({"role": "system", "content": f"Input perception: {command.instruction}"})
        messages = self.trim_context(messages)
        await self.hud("context", messages=messages, tokens=self.last_context_tokens, max_tokens=self.max_context_tokens, fill_pct=self.context_fill_pct)
        # Stream response — llm_call(stream=True) hands us ownership of both
        # the client and the response; we must close them (see finally).
        client, resp = await llm_call(self.model, messages, stream=True)
        full_response = ""
        try:
            # OpenRouter streams SSE lines: "data: {json}", ended by "data: [DONE]".
            async for line in resp.aiter_lines():
                if not line.startswith("data: "):
                    continue
                payload = line[6:]
                if payload == "[DONE]":
                    break
                chunk = json.loads(payload)
                delta = chunk["choices"][0].get("delta", {})
                token = delta.get("content", "")
                if token:
                    full_response += token
                    await ws.send_text(json.dumps({"type": "delta", "content": token}))
        finally:
            # Close the response before the client that owns its connection.
            await resp.aclose()
            await client.aclose()
        log.info(f"[output] response: {full_response[:100]}...")
        await ws.send_text(json.dumps({"type": "done"}))
        await self.hud("done")
        return full_response
# --- Memorizer Node (S2 — shared state / coordination) ---
class MemorizerNode(Node):
    """S2 — coordination. After each exchange, distills the conversation into
    a shared-state dict that Input and Output read via get_context_block()."""
    name = "memorizer"
    model = "google/gemini-2.0-flash-001"
    max_context_tokens = 3000  # needs enough history to distill
    DISTILL_SYSTEM = """You are the Memorizer node of a cognitive agent runtime.
After each exchange you update the shared state that Input and Output nodes read.
Given the conversation so far, output a JSON object with these fields:
- user_name: string — how the user identifies themselves (null if unknown)
- user_mood: string — current emotional tone (neutral, happy, frustrated, playful, etc.)
- topic: string — what the conversation is about right now
- topic_history: list of strings — previous topics in this session
- situation: string — social/physical context if mentioned (e.g. "at a pub with tina", "private dev session")
- language: string — primary language being used (en, de, mixed)
- style_hint: string — how Output should talk (casual, formal, technical, poetic, etc.)
- facts: list of strings — important facts learned about the user
Output ONLY valid JSON. No explanation, no markdown fences."""

    def __init__(self, send_hud):
        super().__init__(send_hud)
        # The shared state — starts with safe defaults, grows over conversation.
        self.state: dict = {
            "user_name": None,
            "user_mood": "neutral",
            "topic": None,
            "topic_history": [],
            "situation": "localhost test runtime, private dev session",
            "language": "en",
            "style_hint": "casual, technical",
            "facts": [],
        }

    def get_context_block(self, sensor_lines: list[str] | None = None) -> str:
        """Returns a formatted string for injection into Input/Output system prompts.

        Works on a copy — the original appended to the caller's
        `sensor_lines` list in place.
        """
        lines = list(sensor_lines) if sensor_lines else ["Sensors: (none)"]
        lines.append("")
        lines.append("Shared memory (from Memorizer):")
        for k, v in self.state.items():
            if v:
                lines.append(f"- {k}: {v}")
        return "\n".join(lines)

    @staticmethod
    def _strip_fences(raw: str) -> str:
        """Strip surrounding markdown code fences from an LLM reply."""
        text = raw.strip()
        if text.startswith("```"):
            # Drop the opening fence line (may carry a language tag).
            text = text.split("\n", 1)[1] if "\n" in text else text[3:]
        if text.endswith("```"):
            text = text[:-3]
        return text.strip()

    def _merge_state(self, new_state: dict) -> None:
        """Merge a freshly distilled state dict into self.state."""
        # Facts: old first, new appended, order-preserving dedupe, cap at the
        # 20 newest. (The original unioned two sets, so the [-20:] cap dropped
        # arbitrary facts instead of the oldest.)
        merged_facts = list(dict.fromkeys(
            [*self.state.get("facts", []), *new_state.get("facts", [])]))
        new_state["facts"] = merged_facts[-20:]
        # Topic history: archive the previous topic when the topic changed.
        if self.state.get("topic") and self.state["topic"] != new_state.get("topic"):
            hist = new_state.get("topic_history", [])
            if self.state["topic"] not in hist:
                hist.append(self.state["topic"])
            new_state["topic_history"] = hist[-5:]  # keep last 5
        self.state = new_state

    async def update(self, history: list[dict]):
        """Distill conversation into updated shared state. Called after each exchange."""
        if len(history) < 2:
            await self.hud("updated", state=self.state)  # emit default state
            return
        await self.hud("thinking", detail="updating shared state")
        messages = [
            {"role": "system", "content": self.DISTILL_SYSTEM},
            {"role": "system", "content": f"Current state: {json.dumps(self.state)}"},
            *history[-10:],
            {"role": "user", "content": "Update the shared state based on this conversation. Output JSON only."},
        ]
        messages = self.trim_context(messages)
        await self.hud("context", messages=messages, tokens=self.last_context_tokens, max_tokens=self.max_context_tokens, fill_pct=self.context_fill_pct)
        raw = await llm_call(self.model, messages)
        log.info(f"[memorizer] raw: {raw[:200]}")
        text = self._strip_fences(raw)
        try:
            parsed = json.loads(text)
            if not isinstance(parsed, dict):
                raise ValueError(f"expected JSON object, got {type(parsed).__name__}")
            self._merge_state(parsed)
        # Narrowed from the original's redundant (JSONDecodeError, Exception).
        except (json.JSONDecodeError, ValueError) as e:
            log.error(f"[memorizer] update error: {e}, raw: {text[:200]}")
            await self.hud("error", detail=f"Update failed: {e}")
            # Still emit current state so frontend shows something
            await self.hud("updated", state=self.state)
            return
        log.info(f"[memorizer] updated state: {self.state}")
        await self.hud("updated", state=self.state)
# --- Runtime (wires nodes together) ---
# Append-only JSONL trace of HUD events, stored next to this module.
TRACE_FILE = Path(__file__).parent / "trace.jsonl"
class Runtime:
    """Wires the four nodes together for one websocket session."""

    def __init__(self, ws: WebSocket, user_claims: dict | None = None, origin: str = ""):
        self.ws = ws
        self.history: list[dict] = []
        self.MAX_HISTORY = 40  # sliding window — oldest messages drop off
        self.input_node = InputNode(send_hud=self._send_hud)
        self.output_node = OutputNode(send_hud=self._send_hud)
        self.memorizer = MemorizerNode(send_hud=self._send_hud)
        self.sensor = SensorNode(send_hud=self._send_hud)
        # Start sensor tick loop (lambda so it always sees the live state dict).
        self.sensor.start(get_memo_state=lambda: self.memorizer.state)
        # Verified identity from auth — Input and Memorizer use this
        claims = user_claims or {}
        log.info(f"[runtime] user_claims: {claims}")
        self.identity = claims.get("name") or claims.get("preferred_username") or claims.get("username") or "unknown"
        log.info(f"[runtime] resolved identity: {self.identity}")
        self.channel = origin or "unknown"
        # Seed memorizer with verified info
        self.memorizer.state["user_name"] = self.identity
        self.memorizer.state["situation"] = f"authenticated on {self.channel}" if origin else "local session"

    async def _send_hud(self, data: dict):
        """Fan a HUD event out: websocket → trace file → SSE subscribers."""
        # Send to frontend
        await self.ws.send_text(json.dumps({"type": "hud", **data}))
        # Append to trace file + broadcast to SSE subscribers.
        # Timestamp carries a millisecond suffix, e.g. "2026-03-28 00:51:43.123".
        trace_entry = {"ts": time.strftime("%Y-%m-%d %H:%M:%S.") + f"{time.time() % 1:.3f}"[2:], **data}
        try:
            with open(TRACE_FILE, "a", encoding="utf-8") as f:
                f.write(json.dumps(trace_entry, ensure_ascii=False) + "\n")
            # Rotate once the file passes ~500 KB, keeping the last 500 lines.
            if TRACE_FILE.exists() and TRACE_FILE.stat().st_size > 500_000:
                lines = TRACE_FILE.read_text(encoding="utf-8").strip().split("\n")
                TRACE_FILE.write_text("\n".join(lines[-500:]) + "\n", encoding="utf-8")
        except Exception as e:
            # Tracing is best-effort — never let it break the exchange.
            log.error(f"trace write error: {e}")
        _broadcast_sse(trace_entry)

    async def handle_message(self, text: str):
        """One full exchange: Input → Output (streamed) → Memorizer update."""
        envelope = Envelope(
            text=text,
            user_id="nico",
            session_id="test",
            timestamp=time.strftime("%Y-%m-%d %H:%M:%S"),
        )
        # Note user activity for idle sensor
        self.sensor.note_user_activity()
        # Append user message to history FIRST — both nodes see it
        self.history.append({"role": "user", "content": text})
        # Get shared memory + sensor context for both nodes
        sensor_lines = self.sensor.get_context_lines()
        mem_ctx = self.memorizer.get_context_block(sensor_lines=sensor_lines)
        # Input node decides (with memory context + identity + channel)
        command = await self.input_node.process(
            envelope, self.history, memory_context=mem_ctx,
            identity=self.identity, channel=self.channel)
        # Output node executes (with memory context + history including user msg)
        response = await self.output_node.process(command, self.history, self.ws, memory_context=mem_ctx)
        self.history.append({"role": "assistant", "content": response})
        # Memorizer updates shared state after each exchange
        await self.memorizer.update(self.history)
        # Sliding window — trim oldest messages, keep context in memorizer
        if len(self.history) > self.MAX_HISTORY:
            self.history = self.history[-self.MAX_HISTORY:]
# --- App ---
STATIC_DIR = Path(__file__).parent / "static"
app = FastAPI(title="cog")
# Keep a reference to the active runtime for API access
_active_runtime: Runtime | None = None
@app.get("/health")
async def health():
return {"status": "ok"}
@app.get("/auth/config")
async def auth_config():
"""Public: auth config for frontend OIDC flow."""
return {
"enabled": AUTH_ENABLED,
"issuer": ZITADEL_ISSUER,
"clientId": ZITADEL_CLIENT_ID,
"projectId": ZITADEL_PROJECT_ID,
}
@app.websocket("/ws")
async def ws_endpoint(ws: WebSocket, token: str | None = Query(None), access_token: str | None = Query(None)):
global _active_runtime
# Validate auth if enabled
user_claims = {"sub": "anonymous"}
if AUTH_ENABLED and token:
try:
user_claims = await _validate_token(token)
# If id_token lacks name, enrich from userinfo with access_token
if not user_claims.get("name") and access_token:
async with httpx.AsyncClient() as client:
resp = await client.get(f"{ZITADEL_ISSUER}/oidc/v1/userinfo",
headers={"Authorization": f"Bearer {access_token}"})
if resp.status_code == 200:
info = resp.json()
log.info(f"[auth] userinfo enrichment: {info}")
user_claims["name"] = info.get("name")
user_claims["preferred_username"] = info.get("preferred_username")
user_claims["email"] = info.get("email")
except HTTPException:
await ws.close(code=4001, reason="Invalid token")
return
origin = ws.headers.get("origin", ws.headers.get("host", ""))
await ws.accept()
runtime = Runtime(ws, user_claims=user_claims, origin=origin)
_active_runtime = runtime
try:
while True:
data = await ws.receive_text()
msg = json.loads(data)
await runtime.handle_message(msg["text"])
except WebSocketDisconnect:
runtime.sensor.stop()
if _active_runtime is runtime:
_active_runtime = None
# --- API endpoints (for Claude to inspect runtime state) ---
import hashlib
from asyncio import Queue
from starlette.responses import StreamingResponse
# SSE subscribers (for titan/service accounts to watch live)
_sse_subscribers: list[Queue] = []
def _broadcast_sse(event: dict):
"""Push an event to all SSE subscribers."""
for q in _sse_subscribers:
try:
q.put_nowait(event)
except asyncio.QueueFull:
pass # drop if subscriber is too slow
def _state_hash() -> str:
    """Hash of current runtime state — cheap way to detect changes."""
    runtime = _active_runtime
    if not runtime:
        return "no_session"
    snapshot = {"mem": runtime.memorizer.state, "hlen": len(runtime.history)}
    digest = hashlib.md5(json.dumps(snapshot, sort_keys=True).encode()).hexdigest()
    return digest[:12]
@app.get("/api/events")
async def sse_events(user=Depends(require_auth)):
"""SSE stream of runtime events (trace, state changes)."""
q: Queue = Queue(maxsize=100)
_sse_subscribers.append(q)
async def generate():
try:
while True:
event = await q.get()
yield f"data: {json.dumps(event)}\n\n"
except asyncio.CancelledError:
pass
finally:
_sse_subscribers.remove(q)
return StreamingResponse(generate(), media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
@app.get("/api/poll")
async def poll(since: str = "", user=Depends(require_auth)):
"""Returns current hash. If 'since' matches, returns {changed: false}. Cheap polling."""
h = _state_hash()
if since and since == h:
return {"changed": False, "hash": h}
return {
"changed": True,
"hash": h,
"state": _active_runtime.memorizer.state if _active_runtime else None,
"history_len": len(_active_runtime.history) if _active_runtime else 0,
"last_messages": _active_runtime.history[-3:] if _active_runtime else [],
}
@app.post("/api/send")
async def api_send(body: dict, user=Depends(require_auth)):
"""Send a message as if the user typed it. Requires auth. Returns the response."""
if not _active_runtime:
raise HTTPException(status_code=409, detail="No active session — someone must be connected via WS first")
text = body.get("text", "").strip()
if not text:
raise HTTPException(status_code=400, detail="Missing 'text' field")
await _active_runtime.handle_message(text)
return {
"status": "ok",
"response": _active_runtime.history[-1]["content"] if _active_runtime.history else "",
"memorizer": _active_runtime.memorizer.state,
}
@app.post("/api/clear")
async def api_clear(user=Depends(require_auth)):
"""Clear conversation history."""
if not _active_runtime:
raise HTTPException(status_code=409, detail="No active session")
_active_runtime.history.clear()
return {"status": "cleared"}
@app.get("/api/state")
async def get_state(user=Depends(require_auth)):
"""Current memorizer state + history length."""
if not _active_runtime:
return {"status": "no_session"}
return {
"status": "active",
"memorizer": _active_runtime.memorizer.state,
"history_len": len(_active_runtime.history),
}
@app.get("/api/history")
async def get_history(last: int = 10, user=Depends(require_auth)):
"""Recent conversation history."""
if not _active_runtime:
return {"status": "no_session", "messages": []}
return {
"status": "active",
"messages": _active_runtime.history[-last:],
}
@app.get("/api/trace")
async def get_trace(last: int = 30, user=Depends(require_auth)):
"""Recent trace lines from trace.jsonl."""
if not TRACE_FILE.exists():
return {"lines": []}
lines = TRACE_FILE.read_text(encoding="utf-8").strip().split("\n")
parsed = []
for line in lines[-last:]:
try:
parsed.append(json.loads(line))
except json.JSONDecodeError:
pass
return {"lines": parsed}
# Serve index.html explicitly, then static assets
from fastapi.responses import FileResponse
@app.get("/")
async def index():
    """Serve the SPA shell."""
    index_path = STATIC_DIR / "index.html"
    return FileResponse(index_path)
@app.get("/callback")
async def callback():
"""OIDC callback — serves the same SPA, JS handles the code exchange."""
return FileResponse(STATIC_DIR / "index.html")
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
if __name__ == "__main__":
import uvicorn
uvicorn.run("agent:app", host="0.0.0.0", port=8000, reload=True)