"""
Cognitive Agent Runtime — Phase A.2: Three-node graph (Input → Output + Memorizer).

Input decides WHAT to do. Output executes and streams.
Memorizer holds shared state (S2 — coordination).
"""

import asyncio
import base64
import collections
import hashlib
import hmac
import json
import logging
import os
import time
from asyncio import Queue
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

import httpx
from dotenv import load_dotenv
from fastapi import Depends, FastAPI, HTTPException, Query, WebSocket, WebSocketDisconnect
from fastapi.responses import FileResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi.staticfiles import StaticFiles
from starlette.responses import StreamingResponse

load_dotenv(Path(__file__).parent / ".env")

# --- Logging ---
# Configured before any helper is defined so every function below can rely on `log`.
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(name)s] %(message)s", datefmt="%H:%M:%S")
log = logging.getLogger("runtime")

# --- Config ---
API_KEY = os.environ["OPENROUTER_API_KEY"]  # required: import fails fast if missing
OPENROUTER_URL = "https://openrouter.ai/api/v1/chat/completions"

# --- Auth (Zitadel OIDC) ---
ZITADEL_ISSUER = os.environ.get("ZITADEL_ISSUER", "https://auth.loop42.de")
ZITADEL_CLIENT_ID = os.environ.get("ZITADEL_CLIENT_ID", "365996029172056091")
ZITADEL_PROJECT_ID = os.environ.get("ZITADEL_PROJECT_ID", "365995955654230043")
AUTH_ENABLED = os.environ.get("AUTH_ENABLED", "false").lower() == "true"
# Comma-separated static bearer tokens for machine accounts. Entries are
# whitespace-stripped so "tok1, tok2" behaves like "tok1,tok2"; empty entries are ignored.
SERVICE_TOKENS = {t.strip() for t in os.environ.get("SERVICE_TOKENS", "").split(",") if t.strip()}

JWKS_TTL_SECONDS = 3600  # how long fetched signing keys are reused before refetching
_jwks_cache: dict = {"keys": [], "fetched_at": 0}


async def _get_jwks():
    """Return the issuer's JSON Web Key Set, cached for JWKS_TTL_SECONDS.

    Raises:
        httpx.HTTPError: if the keys endpoint is unreachable or returns an error status.
    """
    if time.time() - _jwks_cache["fetched_at"] < JWKS_TTL_SECONDS:
        return _jwks_cache["keys"]
    async with httpx.AsyncClient() as client:
        resp = await client.get(f"{ZITADEL_ISSUER}/oauth/v2/keys")
        resp.raise_for_status()
    _jwks_cache["keys"] = resp.json()["keys"]
    _jwks_cache["fetched_at"] = time.time()
    return _jwks_cache["keys"]


def _b64url_decode(segment: str) -> bytes:
    """Decode an unpadded base64url segment (JWT header/payload encoding)."""
    # -len % 4 adds 0..3 '=' characters; never a full extra block of padding.
    return base64.urlsafe_b64decode(segment + "=" * (-len(segment) % 4))


def _is_service_token(token: str) -> bool:
    """True if `token` matches a configured static service token (constant-time compare)."""
    candidate = token.encode("utf-8")
    # Compare as bytes: hmac.compare_digest rejects non-ASCII str arguments.
    return any(hmac.compare_digest(candidate, t.encode("utf-8")) for t in SERVICE_TOKENS)


async def _decode_jwt(token: str, header_segment: str) -> dict | None:
    """Verify `token` as an RS256 JWT signed by a key from the issuer's JWKS.

    Returns the verified claims, or None when no JWKS key matches the token's `kid`.
    Raises whatever PyJWT / httpx raise on invalid signature, expiry, or fetch failure.
    """
    # PyJWT is imported lazily so the module still loads without it (JWT path is then skipped).
    import jwt as pyjwt
    from jwt import PyJWK

    header = json.loads(_b64url_decode(header_segment))
    kid = header.get("kid")
    keys = await _get_jwks()
    key = next((k for k in keys if k.get("kid") == kid), None)
    if key is None:
        return None
    return pyjwt.decode(
        token,
        PyJWK(key).key,
        algorithms=["RS256"],
        issuer=ZITADEL_ISSUER,
        options={"verify_aud": False},
    )


async def _validate_token(token: str) -> dict:
    """Validate a bearer token and return its claims.

    Methods are tried in order:
      1. static service tokens (machine accounts like titan),
      2. signed JWT verified against the issuer's JWKS (audience not checked),
      3. Zitadel's userinfo endpoint (covers opaque access tokens).

    Raises:
        HTTPException(401): no method accepted the token.
        HTTPException(503): the identity provider could not be reached for userinfo.
    """
    if _is_service_token(token):
        return {"sub": "titan", "username": "titan", "source": "service_token"}

    parts = token.split(".")
    if len(parts) == 3:
        try:
            claims = await _decode_jwt(token, parts[0])
        except Exception as e:
            # Bad signature, expired, malformed header, JWKS outage, PyJWT missing...
            # The userinfo check below is the fallback, so this is not fatal.
            log.debug(f"[auth] JWT validation failed, falling back to userinfo: {e}")
        else:
            if claims is not None:
                return claims

    # Fall back to the userinfo endpoint: the SPA client is public (no client secret),
    # so token introspection is unavailable; userinfo accepts the token itself.
    try:
        async with httpx.AsyncClient() as client:
            resp = await client.get(
                f"{ZITADEL_ISSUER}/oidc/v1/userinfo",
                headers={"Authorization": f"Bearer {token}"},
            )
    except httpx.HTTPError as e:
        log.error(f"[auth] userinfo request failed: {e}")
        raise HTTPException(status_code=503, detail="Auth provider unavailable") from e
    if resp.status_code == 200:
        info = resp.json()
        log.info(f"[auth] userinfo response: {info}")
        return {
            "sub": info.get("sub"),
            "preferred_username": info.get("preferred_username"),
            "email": info.get("email"),
            "name": info.get("name"),
            "source": "userinfo",
        }
    raise HTTPException(status_code=401, detail="Invalid token")


_bearer = HTTPBearer(auto_error=False)


async def require_auth(credentials: HTTPAuthorizationCredentials | None = Depends(_bearer)):
    """Dependency: require a valid bearer token when AUTH_ENABLED; anonymous otherwise."""
    if not AUTH_ENABLED:
        return {"sub": "anonymous"}
    if not credentials:
        raise HTTPException(status_code=401, detail="Missing token")
    return await _validate_token(credentials.credentials)


async def ws_auth(token: str | None = Query(None)) -> dict | None:
    """Validate a WebSocket token from the query string.

    Returns None when auth is enabled but no token was supplied; the caller must reject.
    """
    if not AUTH_ENABLED:
        return {"sub": "anonymous"}
    if not token:
        return None
    return await _validate_token(token)


# --- LLM helper ---

async def llm_call(model: str, messages: list[dict], stream: bool = False) -> Any:
    """Single chat-completion call via OpenRouter.

    Non-streaming: returns the reply text. Failures are reported as a
    "[LLM error: ...]" string rather than raised. A None content is returned as "".

    Streaming: returns (client, response) with the response body still open;
    the caller owns closing both.

    Raises:
        httpx.HTTPError: on transport failures (connect/read timeouts etc.).
    """
    headers = {"Authorization": f"Bearer {API_KEY}", "Content-Type": "application/json"}
    body = {"model": model, "messages": messages, "stream": stream}
    client = httpx.AsyncClient(timeout=60)

    if stream:
        try:
            request = client.build_request("POST", OPENROUTER_URL, headers=headers, json=body)
            resp = await client.send(request, stream=True)
        except BaseException:
            # Without this the client would leak when the request itself fails.
            await client.aclose()
            raise
        return client, resp  # caller owns cleanup

    try:
        resp = await client.post(OPENROUTER_URL, headers=headers, json=body)
    finally:
        await client.aclose()

    try:
        data = resp.json()
    except ValueError:
        log.error(f"LLM error: HTTP {resp.status_code}: {resp.text[:200]}")
        return f"[LLM error: HTTP {resp.status_code}]"

    if not isinstance(data, dict) or "choices" not in data:
        log.error(f"LLM error: {data}")
        err = data.get("error") if isinstance(data, dict) else None
        message = err.get("message", "unknown") if isinstance(err, dict) else (err or "unknown")
        return f"[LLM error: {message}]"
    # Content can be null (e.g. refusals/tool calls); callers expect a str.
    return data["choices"][0]["message"]["content"] or ""


# --- Message types ---

@dataclass
class Envelope:
    """What flows between nodes."""
    text: str
    user_id: str = "anon"
    session_id: str = ""
    timestamp: str = ""


@dataclass
class Command:
    """Input node's decision — tells Output what to do."""
    instruction: str  # natural language command for Output LLM
    source_text: str  # original user message (Output may need it)
    metadata: dict = field(default_factory=dict)


# --- Base Node ---

class Node:
    """Common base: a named node that can emit HUD events to the frontend."""
    name: str = "node"
    model: str | None = None

    def __init__(self, send_hud):
        self.send_hud = send_hud  # async callable to emit hud events to frontend

    async def hud(self, event: str, **data):
        """Emit one HUD event tagged with this node's name."""
        await self.send_hud({"node": self.name, "event": event, **data})


# --- Input Node ---

class InputNode(Node):
    """Decides HOW Output should respond; emits a one-sentence instruction."""
    name = "input"
    model = "google/gemini-2.0-flash-001"

    SYSTEM = """You are the Input node — the ear of this cognitive runtime.

Listener context:
- Authenticated user: {identity}
- Channel: {channel} (Chrome browser on Nico's Windows PC, in his room at home)
- Physical: private space, Nico lives with Tina — she may use this session too
- Security: single-user account, shared physical space — other voices are trusted household

You hear what comes through this channel. Emit ONE instruction sentence telling Output how to respond.
No content, just the command.

{memory_context}"""

    async def process(self, envelope: Envelope, history: list[dict], memory_context: str = "",
                      identity: str = "unknown", channel: str = "unknown") -> Command:
        """Ask the LLM for an instruction given recent history (which already holds the user turn)."""
        await self.hud("thinking", detail="deciding how to respond")
        log.info(f"[input] user said: {envelope.text}")

        messages = [
            {"role": "system", "content": self.SYSTEM.format(
                memory_context=memory_context, identity=identity, channel=channel)},
        ]
        # History already includes current user message — don't add it again
        messages.extend(history[-8:])
        await self.hud("context", messages=messages)

        instruction = await llm_call(self.model, messages)
        log.info(f"[input] → command: {instruction}")
        await self.hud("decided", instruction=instruction)
        return Command(instruction=instruction, source_text=envelope.text)


# --- Output Node ---

class OutputNode(Node):
    """Executes Input's command and streams the reply over the WebSocket."""
    name = "output"
    model = "google/gemini-2.0-flash-001"

    SYSTEM = """You are the Output node of a cognitive agent runtime.
You receive a command from the Input node telling you HOW to respond, plus the user's original message.
Follow the command's tone and intent. Be natural, don't mention the command or the runtime architecture.
Be concise.

{memory_context}"""

    async def process(self, command: Command, history: list[dict], ws: WebSocket, memory_context: str = "") -> str:
        """Stream the reply as {"type": "delta"} frames, then {"type": "done"}; return the full text."""
        await self.hud("streaming")

        messages = [
            {"role": "system", "content": self.SYSTEM.format(memory_context=memory_context)},
        ]
        # Conversation history for continuity (already includes current user message)
        messages.extend(history[-20:])
        # Inject command as system guidance after the user message
        messages.append({"role": "system", "content": f"Input node command: {command.instruction}"})
        await self.hud("context", messages=messages)

        client, resp = await llm_call(self.model, messages, stream=True)
        pieces: list[str] = []
        try:
            if resp.status_code != 200:
                # Error bodies are plain JSON, not SSE; surface them instead of streaming nothing.
                body = (await resp.aread()).decode("utf-8", errors="replace")
                log.error(f"[output] LLM HTTP {resp.status_code}: {body[:200]}")
                error_text = f"[LLM error: HTTP {resp.status_code}]"
                pieces.append(error_text)
                await ws.send_text(json.dumps({"type": "delta", "content": error_text}))
            else:
                async for line in resp.aiter_lines():
                    # Skip blank keep-alives and ": comment" lines; only "data:" carries payload.
                    if not line.startswith("data:"):
                        continue
                    payload = line[5:]
                    if payload.startswith(" "):  # SSE strips exactly one leading space
                        payload = payload[1:]
                    if payload == "[DONE]":
                        break
                    try:
                        chunk = json.loads(payload)
                    except json.JSONDecodeError:
                        log.warning(f"[output] skipping malformed stream chunk: {payload[:100]}")
                        continue
                    if not isinstance(chunk, dict):
                        continue
                    if "error" in chunk:
                        log.error(f"[output] stream error: {chunk['error']}")
                        break
                    choices = chunk.get("choices") or []
                    if not choices:
                        continue
                    token = (choices[0].get("delta") or {}).get("content")
                    if token:
                        pieces.append(token)
                        await ws.send_text(json.dumps({"type": "delta", "content": token}))
        finally:
            await resp.aclose()
            await client.aclose()

        full_response = "".join(pieces)
        log.info(f"[output] response: {full_response[:100]}...")
        await ws.send_text(json.dumps({"type": "done"}))
        await self.hud("done")
        return full_response


# --- Memorizer Node (S2 — shared state / coordination) ---

class MemorizerNode(Node):
    """Distills the conversation into a shared state dict that Input and Output read."""
    name = "memorizer"
    model = "google/gemini-2.0-flash-001"

    DISTILL_SYSTEM = """You are the Memorizer node of a cognitive agent runtime.
After each exchange you update the shared state that Input and Output nodes read.

Given the conversation so far, output a JSON object with these fields:
- user_name: string — how the user identifies themselves (null if unknown)
- user_mood: string — current emotional tone (neutral, happy, frustrated, playful, etc.)
- topic: string — what the conversation is about right now
- topic_history: list of strings — previous topics in this session
- situation: string — social/physical context if mentioned (e.g. "at a pub with tina", "private dev session")
- language: string — primary language being used (en, de, mixed)
- style_hint: string — how Output should talk (casual, formal, technical, poetic, etc.)
- facts: list of strings — important facts learned about the user

Output ONLY valid JSON. No explanation, no markdown fences."""

    TOPIC_HISTORY_LIMIT = 5  # most recent previous topics kept

    def __init__(self, send_hud):
        super().__init__(send_hud)
        # The shared state — starts empty, grows over conversation
        self.state: dict = {
            "user_name": None,
            "user_mood": "neutral",
            "topic": None,
            "topic_history": [],
            "situation": "localhost test runtime, private dev session",
            "language": "en",
            "style_hint": "casual, technical",
            "facts": [],
        }

    def get_context_block(self) -> str:
        """Returns a formatted string for injection into Input/Output system prompts."""
        lines = ["Shared memory (from Memorizer):"]
        lines.extend(f"- {k}: {v}" for k, v in self.state.items() if v)
        return "\n".join(lines)

    @staticmethod
    def _strip_fences(raw: str) -> str:
        """Remove a surrounding ``` markdown fence (with optional language tag) if present."""
        text = raw.strip()
        if text.startswith("```"):
            text = text.split("\n", 1)[1] if "\n" in text else text[3:]
        if text.endswith("```"):
            text = text[:-3]
        return text.strip()

    def _merge_state(self, parsed: Any) -> dict:
        """Combine the LLM's proposed state with the current one.

        Keys the LLM omitted keep their current value. Facts are unioned,
        order-preserving and de-duplicated. When the topic changes, the
        outgoing topic is appended to topic_history (last TOPIC_HISTORY_LIMIT kept).

        Raises:
            ValueError: if `parsed` is not a JSON object.
        """
        if not isinstance(parsed, dict):
            raise ValueError(f"expected a JSON object, got {type(parsed).__name__}")
        merged = {**self.state, **parsed}

        old_facts = self.state.get("facts") or []
        new_facts = parsed.get("facts") or []
        if not isinstance(new_facts, list):
            new_facts = [new_facts]
        # dict.fromkeys keeps first-seen order (a set would reshuffle facts every update).
        merged["facts"] = list(dict.fromkeys(str(f) for f in [*old_facts, *new_facts]))

        old_topic = self.state.get("topic")
        if old_topic and old_topic != merged.get("topic"):
            history = merged.get("topic_history")
            history = list(history) if isinstance(history, list) else []  # copy: never mutate old state
            if old_topic not in history:
                history.append(old_topic)
            merged["topic_history"] = history[-self.TOPIC_HISTORY_LIMIT:]
        return merged

    async def update(self, history: list[dict]):
        """Distill conversation into updated shared state. Called after each exchange.

        Best-effort: on an unparseable LLM reply the previous state is kept and an
        "error" HUD event is emitted.
        """
        if len(history) < 2:
            await self.hud("updated", state=self.state)  # emit default state
            return

        await self.hud("thinking", detail="updating shared state")

        messages = [
            {"role": "system", "content": self.DISTILL_SYSTEM},
            {"role": "system", "content": f"Current state: {json.dumps(self.state)}"},
        ]
        # Last few exchanges for distillation
        messages.extend(history[-10:])
        messages.append({"role": "user", "content": "Update the shared state based on this conversation. Output JSON only."})
        await self.hud("context", messages=messages)

        raw = await llm_call(self.model, messages)
        log.info(f"[memorizer] raw: {raw[:200]}")

        text = self._strip_fences(raw)
        try:
            new_state = self._merge_state(json.loads(text))
        except (ValueError, TypeError) as e:  # JSONDecodeError is a ValueError
            log.error(f"[memorizer] update error: {e}, raw: {text[:200]}")
            await self.hud("error", detail=f"Update failed: {e}")
            # Still emit current state so frontend shows something
            await self.hud("updated", state=self.state)
            return

        self.state = new_state
        log.info(f"[memorizer] updated state: {self.state}")
        await self.hud("updated", state=self.state)


# --- Runtime (wires nodes together) ---

TRACE_FILE = Path(__file__).parent / "trace.jsonl"


class Runtime:
    """One WebSocket session: conversation history plus the Input/Output/Memorizer nodes."""

    def __init__(self, ws: WebSocket, user_claims: dict | None = None, origin: str = ""):
        self.ws = ws
        self.history: list[dict] = []
        self.input_node = InputNode(send_hud=self._send_hud)
        self.output_node = OutputNode(send_hud=self._send_hud)
        self.memorizer = MemorizerNode(send_hud=self._send_hud)

        # Verified identity from auth — Input and Memorizer use this
        claims = user_claims or {}
        log.info(f"[runtime] user_claims: {claims}")
        self.user_id = claims.get("sub") or "anon"
        self.identity = claims.get("name") or claims.get("preferred_username") or claims.get("username") or "unknown"
        log.info(f"[runtime] resolved identity: {self.identity}")
        self.channel = origin or "unknown"

        # Seed memorizer with verified info
        self.memorizer.state["user_name"] = self.identity
        self.memorizer.state["situation"] = f"authenticated on {self.channel}" if origin else "local session"

    async def _send_hud(self, data: dict):
        """Send a HUD event to the frontend, append it to trace.jsonl, and fan out to SSE."""
        await self.ws.send_text(json.dumps({"type": "hud", **data}))

        # Single clock read; truncated millis are always 000-999 (rounding could yield "1.000").
        now = time.time()
        millis = int((now % 1) * 1000)
        ts = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(now)) + f".{millis:03d}"
        trace_entry = {"ts": ts, **data}
        try:
            with open(TRACE_FILE, "a", encoding="utf-8") as f:
                f.write(json.dumps(trace_entry, ensure_ascii=False) + "\n")
        except Exception as e:
            log.error(f"trace write error: {e}")
        _broadcast_sse(trace_entry)

    async def handle_message(self, text: str):
        """Run one exchange: Input decides, Output streams, Memorizer updates state."""
        envelope = Envelope(
            text=text,
            user_id=self.user_id,
            session_id="test",
            timestamp=time.strftime("%Y-%m-%d %H:%M:%S"),
        )

        # Append user message to history FIRST — both nodes see it
        self.history.append({"role": "user", "content": text})

        # Get shared memory context for both nodes
        mem_ctx = self.memorizer.get_context_block()

        # Input node decides (with memory context + identity + channel)
        command = await self.input_node.process(
            envelope, self.history, memory_context=mem_ctx,
            identity=self.identity, channel=self.channel)

        # Output node executes (with memory context + history including user msg)
        response = await self.output_node.process(command, self.history, self.ws, memory_context=mem_ctx)
        self.history.append({"role": "assistant", "content": response})

        # Memorizer updates shared state after each exchange
        await self.memorizer.update(self.history)


# --- App ---

STATIC_DIR = Path(__file__).parent / "static"

app = FastAPI(title="Cognitive Agent Runtime")

# Keep a reference to the active runtime for API access
_active_runtime: Runtime | None = None


@app.get("/health")
async def health():
    """Liveness probe."""
    return {"status": "ok"}


@app.get("/auth/config")
async def auth_config():
    """Public: auth config for frontend OIDC flow."""
    return {
        "enabled": AUTH_ENABLED,
        "issuer": ZITADEL_ISSUER,
        "clientId": ZITADEL_CLIENT_ID,
        "projectId": ZITADEL_PROJECT_ID,
    }


async def _enrich_from_userinfo(user_claims: dict, access_token: str) -> None:
    """Best-effort: fill name/username/email in `user_claims` from the userinfo endpoint."""
    try:
        async with httpx.AsyncClient() as client:
            resp = await client.get(f"{ZITADEL_ISSUER}/oidc/v1/userinfo",
                                    headers={"Authorization": f"Bearer {access_token}"})
        if resp.status_code != 200:
            return
        info = resp.json()
    except (httpx.HTTPError, ValueError) as e:
        log.warning(f"[auth] userinfo enrichment failed: {e}")
        return
    log.info(f"[auth] userinfo enrichment: {info}")
    user_claims["name"] = info.get("name")
    user_claims["preferred_username"] = info.get("preferred_username")
    user_claims["email"] = info.get("email")


@app.websocket("/ws")
async def ws_endpoint(ws: WebSocket, token: str | None = Query(None), access_token: str | None = Query(None)):
    """Chat socket: expects {"text": ...} JSON frames; streams replies and HUD events back."""
    global _active_runtime

    user_claims: dict = {"sub": "anonymous"}
    if AUTH_ENABLED:
        # Reject before accept: with auth on, a missing token must not fall through to anonymous.
        if not token:
            await ws.close(code=4001, reason="Missing token")
            return
        try:
            user_claims = await _validate_token(token)
        except HTTPException:
            await ws.close(code=4001, reason="Invalid token")
            return
        # If id_token lacks name, enrich from userinfo with access_token
        if not user_claims.get("name") and access_token:
            await _enrich_from_userinfo(user_claims, access_token)

    origin = ws.headers.get("origin", ws.headers.get("host", ""))
    await ws.accept()
    runtime = Runtime(ws, user_claims=user_claims, origin=origin)
    _active_runtime = runtime

    try:
        while True:
            data = await ws.receive_text()
            try:
                msg = json.loads(data)
                text = msg["text"]
            except (ValueError, KeyError, TypeError) as e:
                log.warning(f"[ws] ignoring malformed message: {e}")
                continue
            if not isinstance(text, str):
                log.warning("[ws] ignoring message with non-string 'text'")
                continue
            await runtime.handle_message(text)
    except WebSocketDisconnect:
        pass
    finally:
        # Runs on any exit (including unexpected errors) so APIs never see a dead session.
        if _active_runtime is runtime:
            _active_runtime = None


# --- API endpoints (for Claude to inspect runtime state) ---

# SSE subscribers (for titan/service accounts to watch live)
_sse_subscribers: list[Queue] = []


def _broadcast_sse(event: dict):
    """Push an event to all SSE subscribers."""
    for q in list(_sse_subscribers):
        try:
            q.put_nowait(event)
        except asyncio.QueueFull:
            pass  # drop if subscriber is too slow


def _state_hash() -> str:
    """Hash of current runtime state — cheap way to detect changes."""
    if not _active_runtime:
        return "no_session"
    raw = json.dumps({
        "mem": _active_runtime.memorizer.state,
        "hlen": len(_active_runtime.history),
    }, sort_keys=True)
    # Change detection only, not security; flag it so FIPS-restricted builds allow md5.
    return hashlib.md5(raw.encode(), usedforsecurity=False).hexdigest()[:12]


def _tail(items: list, last: int) -> list:
    """Return the final `last` items; empty when last <= 0 (items[-0:] would be everything)."""
    return items[-last:] if last > 0 else []


@app.get("/api/events")
async def sse_events(user=Depends(require_auth)):
    """SSE stream of runtime events (trace, state changes)."""
    q: Queue = Queue(maxsize=100)
    _sse_subscribers.append(q)

    async def generate():
        try:
            while True:
                event = await q.get()
                yield f"data: {json.dumps(event)}\n\n"
        finally:
            # Runs on client disconnect (cancellation) as well; cancellation is not swallowed.
            if q in _sse_subscribers:
                _sse_subscribers.remove(q)

    return StreamingResponse(generate(), media_type="text/event-stream",
                             headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})


@app.get("/api/poll")
async def poll(since: str = "", user=Depends(require_auth)):
    """Returns current hash. If 'since' matches, returns {changed: false}. Cheap polling."""
    h = _state_hash()
    if since and since == h:
        return {"changed": False, "hash": h}
    return {
        "changed": True,
        "hash": h,
        "state": _active_runtime.memorizer.state if _active_runtime else None,
        "history_len": len(_active_runtime.history) if _active_runtime else 0,
        "last_messages": _active_runtime.history[-3:] if _active_runtime else [],
    }


@app.get("/api/state")
async def get_state(user=Depends(require_auth)):
    """Current memorizer state + history length."""
    if not _active_runtime:
        return {"status": "no_session"}
    return {
        "status": "active",
        "memorizer": _active_runtime.memorizer.state,
        "history_len": len(_active_runtime.history),
    }


@app.get("/api/history")
async def get_history(last: int = 10, user=Depends(require_auth)):
    """Recent conversation history (the final `last` messages)."""
    if not _active_runtime:
        return {"status": "no_session", "messages": []}
    return {
        "status": "active",
        "messages": _tail(_active_runtime.history, last),
    }


@app.get("/api/trace")
async def get_trace(last: int = 30, user=Depends(require_auth)):
    """Recent trace lines from trace.jsonl (unparseable lines are skipped)."""
    if last <= 0:
        return {"lines": []}
    try:
        with TRACE_FILE.open(encoding="utf-8") as f:
            # Bounded deque streams the file instead of loading it all into memory.
            tail = collections.deque((ln for ln in f if ln.strip()), maxlen=last)
    except FileNotFoundError:
        return {"lines": []}
    parsed = []
    for line in tail:
        try:
            parsed.append(json.loads(line))
        except json.JSONDecodeError:
            pass  # e.g. a line still being written
    return {"lines": parsed}


# Serve index.html explicitly, then static assets

@app.get("/")
async def index():
    """Serve the single-page frontend."""
    return FileResponse(STATIC_DIR / "index.html")


@app.get("/callback")
async def callback():
    """OIDC callback — serves the same SPA, JS handles the code exchange."""
    return FileResponse(STATIC_DIR / "index.html")


app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run("agent:app", host="0.0.0.0", port=8000, reload=True)