- SensorNode: 5s tick loop with delta-only emissions (clock, idle, memo changes) - Input reframed as perceiver (describes what it heard, not commands) - Output reframed as voice (acts on perception, never echoes it) - Per-node token budgets: Input 2K, Output 4K, Memorizer 3K - fit_context() trims oldest messages to stay within budget - History sliding window: 40 messages max - Facts capped at 20, trace file rotates at 500KB - /api/send + /api/clear endpoints for programmatic testing - test_cog.py test suite - Listener context: physical/social/security awareness Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
847 lines · 32 KiB · Python
"""
|
|
Cognitive Agent Runtime — Phase A.2: Three-node graph (Input → Output + Memorizer).
|
|
Input decides WHAT to do. Output executes and streams.
|
|
Memorizer holds shared state (S2 — coordination).
|
|
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import os
|
|
import time
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import httpx
|
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Depends, HTTPException, Query
|
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
|
from fastapi.staticfiles import StaticFiles
|
|
|
|
from dotenv import load_dotenv
|
|
load_dotenv(Path(__file__).parent / ".env")
|
|
|
|
# --- Config ---

# OpenRouter credentials — required; the process fails fast at import if unset.
API_KEY = os.environ["OPENROUTER_API_KEY"]
OPENROUTER_URL = "https://openrouter.ai/api/v1/chat/completions"

# --- Auth (Zitadel OIDC) ---

ZITADEL_ISSUER = os.environ.get("ZITADEL_ISSUER", "https://auth.loop42.de")
ZITADEL_CLIENT_ID = os.environ.get("ZITADEL_CLIENT_ID", "365996029172056091")
ZITADEL_PROJECT_ID = os.environ.get("ZITADEL_PROJECT_ID", "365995955654230043")
# Auth is opt-in: anything other than the literal string "true" disables it.
AUTH_ENABLED = os.environ.get("AUTH_ENABLED", "false").lower() == "true"
# Comma-separated static bearer tokens for machine accounts; empty entries dropped.
SERVICE_TOKENS = set(filter(None, os.environ.get("SERVICE_TOKENS", "").split(",")))

# Module-level JWKS cache; refreshed at most hourly by _get_jwks().
_jwks_cache: dict = {"keys": [], "fetched_at": 0}
|
|
|
|
async def _get_jwks():
    """Fetch Zitadel's JWKS, caching the key set for one hour.

    Returns the cached keys while fresh; otherwise re-fetches from the
    issuer's /oauth/v2/keys endpoint and refreshes the module-level cache.

    NOTE(review): no lock — concurrent callers may fetch twice; harmless
    since the last writer wins with equivalent data.
    """
    if time.time() - _jwks_cache["fetched_at"] < 3600:
        return _jwks_cache["keys"]
    async with httpx.AsyncClient() as client:
        resp = await client.get(f"{ZITADEL_ISSUER}/oauth/v2/keys")
        # FIX: don't cache an error payload as the key set — a 5xx/HTML body
        # previously raised inside .json()["keys"] or poisoned the cache.
        resp.raise_for_status()
        _jwks_cache["keys"] = resp.json()["keys"]
        _jwks_cache["fetched_at"] = time.time()
    return _jwks_cache["keys"]
|
|
|
|
async def _validate_token(token: str) -> dict:
|
|
"""Validate token: check service tokens, then JWT, then introspection."""
|
|
import base64
|
|
|
|
# Check static service tokens (for machine accounts like titan)
|
|
if token in SERVICE_TOKENS:
|
|
return {"sub": "titan", "username": "titan", "source": "service_token"}
|
|
|
|
# Try JWT validation first
|
|
try:
|
|
parts = token.split(".")
|
|
if len(parts) == 3:
|
|
keys = await _get_jwks()
|
|
header_b64 = parts[0] + "=" * (4 - len(parts[0]) % 4)
|
|
header = json.loads(base64.urlsafe_b64decode(header_b64))
|
|
kid = header.get("kid")
|
|
key = next((k for k in keys if k["kid"] == kid), None)
|
|
if key:
|
|
import jwt as pyjwt
|
|
from jwt import PyJWK
|
|
jwk_obj = PyJWK(key)
|
|
claims = pyjwt.decode(
|
|
token, jwk_obj.key, algorithms=["RS256"],
|
|
issuer=ZITADEL_ISSUER, options={"verify_aud": False},
|
|
)
|
|
return claims
|
|
except Exception:
|
|
pass
|
|
|
|
# Fall back to introspection (for opaque access tokens)
|
|
# Zitadel requires client_id + client_secret or JWT profile for introspection
|
|
# For a public SPA client, use the project's API app instead
|
|
# Simplest: check via userinfo endpoint with the token
|
|
async with httpx.AsyncClient() as client:
|
|
resp = await client.get(
|
|
f"{ZITADEL_ISSUER}/oidc/v1/userinfo",
|
|
headers={"Authorization": f"Bearer {token}"},
|
|
)
|
|
if resp.status_code == 200:
|
|
info = resp.json()
|
|
log.info(f"[auth] userinfo response: {info}")
|
|
return {"sub": info.get("sub"), "preferred_username": info.get("preferred_username"),
|
|
"email": info.get("email"), "name": info.get("name"), "source": "userinfo"}
|
|
|
|
raise HTTPException(status_code=401, detail="Invalid token")
|
|
|
|
_bearer = HTTPBearer(auto_error=False)
|
|
|
|
async def require_auth(credentials: HTTPAuthorizationCredentials | None = Depends(_bearer)):
    """FastAPI dependency: enforce a valid bearer token when AUTH_ENABLED.

    With auth disabled every request is treated as anonymous; otherwise a
    missing credential is a 401 and anything else is delegated to
    _validate_token (which also raises 401 on failure).
    """
    if not AUTH_ENABLED:
        return {"sub": "anonymous"}
    if not credentials:
        raise HTTPException(status_code=401, detail="Missing token")
    return await _validate_token(credentials.credentials)
|
|
|
|
async def ws_auth(token: str | None = Query(None)) -> dict:
    """Validate a WebSocket token passed as a query parameter.

    Returns anonymous claims when auth is off, None for a missing/empty
    token (the caller in ws_endpoint decides how to reject), otherwise
    the validated claims.
    """
    if not AUTH_ENABLED:
        return {"sub": "anonymous"}
    if not token:
        # Missing token: defer the rejection to ws_endpoint.
        return None
    return await _validate_token(token)
|
|
|
|
# --- LLM helper ---

import logging
# Shared console logger for the whole runtime; every node logs through it.
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(name)s] %(message)s", datefmt="%H:%M:%S")
log = logging.getLogger("runtime")
|
|
|
|
|
|
async def llm_call(model: str, messages: list[dict], stream: bool = False) -> Any:
    """Single LLM call via OpenRouter.

    Non-streaming: returns the completion text, or an "[LLM error: ...]"
    string when the API responds with an error payload.
    Streaming: returns (client, response); the CALLER owns closing both.
    """
    headers = {"Authorization": f"Bearer {API_KEY}", "Content-Type": "application/json"}
    body = {"model": model, "messages": messages, "stream": stream}

    client = httpx.AsyncClient(timeout=60)
    if stream:
        try:
            resp = await client.send(client.build_request("POST", OPENROUTER_URL, headers=headers, json=body), stream=True)
        except Exception:
            # FIX: don't leak the client when the request itself fails.
            await client.aclose()
            raise
        return client, resp  # caller owns cleanup

    try:
        resp = await client.post(OPENROUTER_URL, headers=headers, json=body)
        data = resp.json()
    finally:
        # FIX: the original only closed the client on the happy path.
        await client.aclose()
    if "choices" not in data:
        log.error(f"LLM error: {data}")
        return f"[LLM error: {data.get('error', {}).get('message', 'unknown')}]"
    return data["choices"][0]["message"]["content"]
|
|
|
|
|
|
# --- Message types ---
|
|
|
|
@dataclass
class Envelope:
    """What flows between nodes: one user utterance plus routing metadata."""
    text: str  # raw user message as received
    user_id: str = "anon"  # NOTE(review): Runtime.handle_message hardcodes "nico"
    session_id: str = ""  # session label; Runtime uses the constant "test"
    timestamp: str = ""  # wall-clock "%Y-%m-%d %H:%M:%S" string set by the caller
|
|
|
|
|
|
@dataclass
class Command:
    """Input node's decision — tells Output what to do."""
    instruction: str  # natural language command for Output LLM
    source_text: str  # original user message (Output may need it)
    metadata: dict = field(default_factory=dict)  # reserved; never populated in this file
|
|
|
|
|
|
# --- Base Node ---
|
|
|
|
def estimate_tokens(text: str) -> int:
    """Cheap token-count heuristic: roughly one token per four characters."""
    # A right shift by two is floor-division by 4 for the non-negative
    # lengths len() produces.
    return len(text) >> 2
|
|
|
|
|
|
def fit_context(messages: list[dict], max_tokens: int, protect_last: int = 4) -> list[dict]:
    """Trim oldest messages (after system prompt) to fit token budget.

    Always keeps: system prompt(s) at start + last `protect_last` messages.
    When even those exceed the budget, protected messages are dropped
    oldest-first, but the leading system prompts are never removed.
    """
    if not messages:
        return messages

    # Split into system prefix, middle (trimmable), and protected tail
    system_msgs = []
    rest = []
    for m in messages:
        if not rest and m["role"] == "system":
            system_msgs.append(m)
        else:
            rest.append(m)

    protected = rest[-protect_last:] if len(rest) > protect_last else rest
    middle = rest[:-protect_last] if len(rest) > protect_last else []

    # Count fixed tokens (system + protected tail)
    fixed_tokens = sum(estimate_tokens(m["content"]) for m in system_msgs + protected)

    if fixed_tokens >= max_tokens:
        # Even fixed content exceeds budget — truncate protected messages.
        # BUG FIX: the original popped index 1 unconditionally, which removed
        # the SECOND SYSTEM PROMPT when more than one system message led the
        # list (and skipped the true oldest message when there were none).
        # Pop the first message AFTER the system prefix instead.
        result = system_msgs + protected
        total = fixed_tokens
        floor = len(system_msgs) + 1  # keep system prefix + at least one message
        while total > max_tokens and len(result) > max(floor, 2):
            removed = result.pop(len(system_msgs))
            total -= estimate_tokens(removed["content"])
        return result

    # Fill remaining budget with middle messages (newest first)
    remaining = max_tokens - fixed_tokens
    kept_middle = []
    for m in reversed(middle):
        t = estimate_tokens(m["content"])
        if remaining - t < 0:
            break
        kept_middle.insert(0, m)
        remaining -= t

    return system_msgs + kept_middle + protected
|
|
|
|
|
|
class Node:
|
|
name: str = "node"
|
|
model: str | None = None
|
|
max_context_tokens: int = 4000 # default budget per node
|
|
|
|
def __init__(self, send_hud):
|
|
self.send_hud = send_hud # async callable to emit hud events to frontend
|
|
self.last_context_tokens = 0
|
|
|
|
async def hud(self, event: str, **data):
|
|
await self.send_hud({"node": self.name, "event": event, **data})
|
|
|
|
def trim_context(self, messages: list[dict]) -> list[dict]:
|
|
"""Fit messages within this node's token budget."""
|
|
before = len(messages)
|
|
result = fit_context(messages, self.max_context_tokens)
|
|
self.last_context_tokens = sum(estimate_tokens(m["content"]) for m in result)
|
|
self.context_fill_pct = int(100 * self.last_context_tokens / self.max_context_tokens)
|
|
if before != len(result):
|
|
log.info(f"[{self.name}] context trimmed: {before} → {len(result)} msgs, {self.context_fill_pct}% fill")
|
|
return result
|
|
|
|
|
|
# --- Sensor Node (ticks independently, produces context for other nodes) ---

from datetime import datetime, timezone, timedelta

# Fixed UTC+2 offset (CEST). NOTE(review): this does NOT track DST — in
# winter the clock sensor reads one hour ahead; consider
# zoneinfo.ZoneInfo("Europe/Berlin") if correctness across DST matters.
BERLIN = timezone(timedelta(hours=2))  # CEST
|
|
|
|
|
|
class SensorNode(Node):
    """Background sensor loop: ticks every `interval` seconds and emits only
    deltas (clock minute changes, idle-threshold crossings, memorizer state
    changes) to the HUD. Other nodes pull readings via get_context_lines()."""

    name = "sensor"

    def __init__(self, send_hud):
        super().__init__(send_hud)
        self.tick_count = 0
        self.running = False
        self._task: asyncio.Task | None = None
        self.interval = 5  # seconds between ticks
        # Current sensor readings — each value is a dict with at least
        # "value" and "changed_at"; clock also stores "detail", idle "_raw".
        self.readings: dict[str, dict] = {}
        self._last_user_activity: float = time.time()
        # Snapshot of memorizer state for change detection
        self._prev_memo_state: dict = {}

    def _now(self) -> datetime:
        # Wall clock in the fixed BERLIN offset (see module-level note on DST).
        return datetime.now(BERLIN)

    def _read_clock(self) -> dict:
        """Clock sensor — updates when minute changes."""
        now = self._now()
        current = now.strftime("%H:%M")
        prev = self.readings.get("clock", {}).get("value")
        if current != prev:
            return {"value": current, "detail": now.strftime("%Y-%m-%d %H:%M:%S %A"), "changed_at": time.time()}
        return {}  # no change

    def _read_idle(self) -> dict:
        """Idle sensor — time since last user message.

        Emits a delta only when the idle time crosses one of the fixed
        thresholds; otherwise silently refreshes the stored raw value.
        """
        idle_s = time.time() - self._last_user_activity
        # Only update on threshold crossings: 30s, 1m, 5m, 10m, 30m
        thresholds = [30, 60, 300, 600, 1800]
        prev_idle = self.readings.get("idle", {}).get("_raw", 0)
        for t in thresholds:
            if prev_idle < t <= idle_s:
                if idle_s < 60:
                    label = f"{int(idle_s)}s"
                else:
                    label = f"{int(idle_s // 60)}m{int(idle_s % 60)}s"
                return {"value": label, "_raw": idle_s, "changed_at": time.time()}
        # Update raw but don't flag as changed
        if "idle" in self.readings:
            self.readings["idle"]["_raw"] = idle_s
        return {}

    def _read_memo_changes(self, memo_state: dict) -> dict:
        """Detect memorizer state changes since the previous tick.

        New keys (prev is None) are deliberately not reported — only
        transitions between two known values.
        """
        changes = []
        for k, v in memo_state.items():
            prev = self._prev_memo_state.get(k)
            if v != prev and prev is not None:
                changes.append(f"{k}: {prev} -> {v}")
        self._prev_memo_state = dict(memo_state)
        if changes:
            return {"value": "; ".join(changes), "changed_at": time.time()}
        return {}

    def note_user_activity(self):
        """Called when user sends a message — resets the idle sensor."""
        self._last_user_activity = time.time()
        # Reset idle sensor
        self.readings["idle"] = {"value": "active", "_raw": 0, "changed_at": time.time()}

    async def tick(self, memo_state: dict):
        """One tick — read all sensors, emit deltas (HUD event only if any)."""
        self.tick_count += 1
        deltas = {}

        # Read each sensor; merge its partial update into stored readings.
        for name, reader in [("clock", self._read_clock),
                             ("idle", self._read_idle)]:
            update = reader()
            if update:
                self.readings[name] = {**self.readings.get(name, {}), **update}
                deltas[name] = update.get("value") or update.get("detail")

        # Memo changes
        memo_update = self._read_memo_changes(memo_state)
        if memo_update:
            self.readings["memo_delta"] = memo_update
            deltas["memo_delta"] = memo_update["value"]

        # Only emit HUD if something changed
        if deltas:
            await self.hud("tick", tick=self.tick_count, deltas=deltas)

    async def _loop(self, get_memo_state):
        """Background tick loop.

        Per-tick errors are logged and swallowed so one bad tick doesn't
        kill the loop. NOTE(review): the final hud("stopped") can itself
        raise if the websocket is already closed when the task is cancelled.
        """
        self.running = True
        await self.hud("started", interval=self.interval)
        try:
            while self.running:
                await asyncio.sleep(self.interval)
                try:
                    await self.tick(get_memo_state())
                except Exception as e:
                    log.error(f"[sensor] tick error: {e}")
        except asyncio.CancelledError:
            pass
        finally:
            self.running = False
            await self.hud("stopped")

    def start(self, get_memo_state):
        """Start the background tick loop (idempotent while a task is live)."""
        if self._task and not self._task.done():
            return
        self._task = asyncio.create_task(self._loop(get_memo_state))

    def stop(self):
        """Stop the tick loop by flag and cancellation."""
        self.running = False
        if self._task:
            self._task.cancel()

    def get_context_lines(self) -> list[str]:
        """Render current sensor readings for injection into prompts."""
        if not self.readings:
            return ["Sensors: (no sensor node running)"]
        lines = [f"Sensors (tick #{self.tick_count}, {self.interval}s interval):"]
        for name, r in self.readings.items():
            # Underscore-prefixed reading names are internal-only.
            if name.startswith("_"):
                continue
            val = r.get("value", "?")
            detail = r.get("detail")
            age = time.time() - r.get("changed_at", time.time())
            if age < 10:
                age_str = "just now"
            elif age < 60:
                age_str = f"{int(age)}s ago"
            else:
                age_str = f"{int(age // 60)}m ago"
            line = f"- {name}: {detail or val} [{age_str}]"
            lines.append(line)
        return lines
|
|
|
|
|
|
# --- Input Node ---
|
|
|
|
class InputNode(Node):
    """The runtime's ear: distills a raw user message into a one-sentence
    perception that the Output node consumes as internal context."""

    name = "input"
    model = "google/gemini-2.0-flash-001"
    max_context_tokens = 2000  # small budget — perception only

    SYSTEM = """You are the Input node — the ear of this cognitive runtime.

Listener context:
- Authenticated user: {identity}
- Channel: {channel} (Chrome browser on Nico's Windows PC, in his room at home)
- Physical: private space, Nico lives with Tina — she may use this session too
- Security: single-user account, shared physical space — other voices are trusted household

Your job: describe what you heard. Who spoke, what they want, what tone, what context matters.
ONE sentence. No content, no response — just your perception of what came through.

{memory_context}"""

    async def process(self, envelope: Envelope, history: list[dict], memory_context: str = "",
                      identity: str = "unknown", channel: str = "unknown") -> Command:
        """Perceive the incoming message and wrap the perception in a Command."""
        await self.hud("thinking", detail="deciding how to respond")
        log.info(f"[input] user said: {envelope.text}")

        system_prompt = self.SYSTEM.format(
            memory_context=memory_context, identity=identity, channel=channel)
        messages = [{"role": "system", "content": system_prompt}]
        # Recent turns only — this node works at perception granularity.
        messages.extend(history[-8:])
        messages = self.trim_context(messages)

        await self.hud("context", messages=messages)
        instruction = await llm_call(self.model, messages)
        log.info(f"[input] → command: {instruction}")
        await self.hud("perceived", instruction=instruction)
        return Command(instruction=instruction, source_text=envelope.text)
|
|
|
|
|
|
# --- Output Node ---
|
|
|
|
class OutputNode(Node):
    """The runtime's voice: turns the Input node's perception plus history
    into a streamed reply, relaying each token over the websocket."""

    name = "output"
    model = "google/gemini-2.0-flash-001"
    max_context_tokens = 4000  # larger — needs history for continuity

    SYSTEM = """You are the Output node — the voice of this cognitive runtime.
The Input node sends you its perception of what the user said. This is internal context for you — never repeat or echo it.
You respond to the USER, not to the Input node. Use the perception to understand intent, then act on it.
Be natural. Be concise. If the user asks you to do something, do it — don't describe what you're about to do.

{memory_context}"""

    async def process(self, command: Command, history: list[dict], ws: WebSocket, memory_context: str = "") -> str:
        """Stream a reply to the user over `ws`; returns the full text.

        Sends {"type": "delta"} frames per token and a {"type": "done"}
        frame at the end.
        """
        await self.hud("streaming")

        messages = [
            {"role": "system", "content": self.SYSTEM.format(memory_context=memory_context)},
        ]
        for msg in history[-20:]:
            messages.append(msg)
        # Perception rides along as a trailing system message, not user input.
        messages.append({"role": "system", "content": f"Input perception: {command.instruction}"})
        messages = self.trim_context(messages)

        await self.hud("context", messages=messages)

        # Stream response (SSE-framed chunks from OpenRouter; caller-owned
        # client/response pair from llm_call — closed in the finally block).
        client, resp = await llm_call(self.model, messages, stream=True)
        full_response = ""
        try:
            async for line in resp.aiter_lines():
                # Skip comments/keep-alives; only "data: ..." lines carry payload.
                if not line.startswith("data: "):
                    continue
                payload = line[6:]
                if payload == "[DONE]":
                    break
                chunk = json.loads(payload)
                delta = chunk["choices"][0].get("delta", {})
                token = delta.get("content", "")
                if token:
                    full_response += token
                    await ws.send_text(json.dumps({"type": "delta", "content": token}))
        finally:
            await resp.aclose()
            await client.aclose()

        log.info(f"[output] response: {full_response[:100]}...")
        await ws.send_text(json.dumps({"type": "done"}))
        await self.hud("done")
        return full_response
|
|
|
|
|
|
# --- Memorizer Node (S2 — shared state / coordination) ---
|
|
|
|
class MemorizerNode(Node):
    """S2 coordination node: after each exchange, distills the conversation
    into a shared-state dict that Input and Output read every turn."""

    name = "memorizer"
    model = "google/gemini-2.0-flash-001"
    max_context_tokens = 3000  # needs enough history to distill

    DISTILL_SYSTEM = """You are the Memorizer node of a cognitive agent runtime.
After each exchange you update the shared state that Input and Output nodes read.

Given the conversation so far, output a JSON object with these fields:
- user_name: string — how the user identifies themselves (null if unknown)
- user_mood: string — current emotional tone (neutral, happy, frustrated, playful, etc.)
- topic: string — what the conversation is about right now
- topic_history: list of strings — previous topics in this session
- situation: string — social/physical context if mentioned (e.g. "at a pub with tina", "private dev session")
- language: string — primary language being used (en, de, mixed)
- style_hint: string — how Output should talk (casual, formal, technical, poetic, etc.)
- facts: list of strings — important facts learned about the user

Output ONLY valid JSON. No explanation, no markdown fences."""

    def __init__(self, send_hud):
        super().__init__(send_hud)
        # The shared state — starts empty, grows over conversation
        self.state: dict = {
            "user_name": None,
            "user_mood": "neutral",
            "topic": None,
            "topic_history": [],
            "situation": "localhost test runtime, private dev session",
            "language": "en",
            "style_hint": "casual, technical",
            "facts": [],
        }

    def get_context_block(self, sensor_lines: list[str] = None) -> str:
        """Returns a formatted string for injection into Input/Output system prompts.

        Args:
            sensor_lines: pre-rendered sensor lines; left untouched.
        """
        # BUG FIX: copy instead of aliasing — `lines = sensor_lines or [...]`
        # appended to the CALLER'S list in place on every call.
        lines = list(sensor_lines) if sensor_lines else ["Sensors: (none)"]
        lines.append("")
        lines.append("Shared memory (from Memorizer):")
        for k, v in self.state.items():
            if v:  # skip None/empty fields to keep the prompt tight
                lines.append(f"- {k}: {v}")
        return "\n".join(lines)

    async def update(self, history: list[dict]):
        """Distill conversation into updated shared state. Called after each exchange.

        On any failure the previous state is kept and re-emitted so the
        frontend HUD still shows something.
        """
        if len(history) < 2:
            await self.hud("updated", state=self.state)  # emit default state
            return

        await self.hud("thinking", detail="updating shared state")

        messages = [
            {"role": "system", "content": self.DISTILL_SYSTEM},
            {"role": "system", "content": f"Current state: {json.dumps(self.state)}"},
        ]
        for msg in history[-10:]:
            messages.append(msg)
        messages.append({"role": "user", "content": "Update the shared state based on this conversation. Output JSON only."})
        messages = self.trim_context(messages)

        await self.hud("context", messages=messages)

        raw = await llm_call(self.model, messages)
        log.info(f"[memorizer] raw: {raw[:200]}")

        # Parse JSON from response (strip markdown fences if present)
        text = raw.strip()
        if text.startswith("```"):
            text = text.split("\n", 1)[1] if "\n" in text else text[3:]
        if text.endswith("```"):
            text = text[:-3]
        text = text.strip()

        try:
            new_state = json.loads(text)
            # Merge: keep old facts, add new ones. NOTE(review): set union
            # has no stable order, so the 20-fact cap may drop arbitrary
            # facts rather than the oldest.
            old_facts = set(self.state.get("facts", []))
            new_facts = set(new_state.get("facts", []))
            new_state["facts"] = list(old_facts | new_facts)[-20:]  # cap at 20 facts
            # Preserve topic history
            if self.state.get("topic") and self.state["topic"] != new_state.get("topic"):
                hist = new_state.get("topic_history", [])
                if self.state["topic"] not in hist:
                    hist.append(self.state["topic"])
                new_state["topic_history"] = hist[-5:]  # keep last 5
            self.state = new_state
            log.info(f"[memorizer] updated state: {self.state}")
            await self.hud("updated", state=self.state)
        except Exception as e:
            # FIX: `except (json.JSONDecodeError, Exception)` was redundant —
            # Exception already subsumes JSONDecodeError.
            log.error(f"[memorizer] update error: {e}, raw: {text[:200]}")
            await self.hud("error", detail=f"Update failed: {e}")
            # Still emit current state so frontend shows something
            await self.hud("updated", state=self.state)
|
|
|
|
|
|
# --- Runtime (wires nodes together) ---
|
|
|
|
TRACE_FILE = Path(__file__).parent / "trace.jsonl"
|
|
|
|
|
|
class Runtime:
    """Wires the four nodes to one websocket session and owns the
    conversation history, trace file, and sensor tick loop."""

    def __init__(self, ws: WebSocket, user_claims: dict = None, origin: str = ""):
        self.ws = ws
        self.history: list[dict] = []
        self.MAX_HISTORY = 40  # sliding window — oldest messages drop off
        self.input_node = InputNode(send_hud=self._send_hud)
        self.output_node = OutputNode(send_hud=self._send_hud)
        self.memorizer = MemorizerNode(send_hud=self._send_hud)
        self.sensor = SensorNode(send_hud=self._send_hud)
        # Start sensor tick loop (requires a running event loop — this
        # constructor is only called from the async ws_endpoint).
        self.sensor.start(get_memo_state=lambda: self.memorizer.state)
        # Verified identity from auth — Input and Memorizer use this
        claims = user_claims or {}
        log.info(f"[runtime] user_claims: {claims}")
        self.identity = claims.get("name") or claims.get("preferred_username") or claims.get("username") or "unknown"
        log.info(f"[runtime] resolved identity: {self.identity}")
        self.channel = origin or "unknown"
        # Seed memorizer with verified info
        self.memorizer.state["user_name"] = self.identity
        self.memorizer.state["situation"] = f"authenticated on {self.channel}" if origin else "local session"

    async def _send_hud(self, data: dict):
        """Fan a HUD event out to the frontend, the trace file, and SSE."""
        # Send to frontend
        await self.ws.send_text(json.dumps({"type": "hud", **data}))
        # Append to trace file + broadcast to SSE subscribers
        trace_entry = {"ts": time.strftime("%Y-%m-%d %H:%M:%S.") + f"{time.time() % 1:.3f}"[2:], **data}
        try:
            with open(TRACE_FILE, "a", encoding="utf-8") as f:
                f.write(json.dumps(trace_entry, ensure_ascii=False) + "\n")
            # Rotate: once the file exceeds 500 KB keep only the last 500 lines.
            if TRACE_FILE.exists() and TRACE_FILE.stat().st_size > 500_000:
                lines = TRACE_FILE.read_text(encoding="utf-8").strip().split("\n")
                TRACE_FILE.write_text("\n".join(lines[-500:]) + "\n", encoding="utf-8")
        except Exception as e:
            # Tracing is best-effort; never let it break the session.
            log.error(f"trace write error: {e}")
        _broadcast_sse(trace_entry)

    async def handle_message(self, text: str):
        """One full exchange: perceive (Input) → respond (Output) → distill (Memorizer)."""
        envelope = Envelope(
            text=text,
            user_id="nico",  # NOTE(review): hardcoded; verified identity lives in self.identity
            session_id="test",
            timestamp=time.strftime("%Y-%m-%d %H:%M:%S"),
        )

        # Note user activity for idle sensor
        self.sensor.note_user_activity()

        # Append user message to history FIRST — both nodes see it
        self.history.append({"role": "user", "content": text})

        # Get shared memory + sensor context for both nodes
        sensor_lines = self.sensor.get_context_lines()
        mem_ctx = self.memorizer.get_context_block(sensor_lines=sensor_lines)

        # Input node decides (with memory context + identity + channel)
        command = await self.input_node.process(
            envelope, self.history, memory_context=mem_ctx,
            identity=self.identity, channel=self.channel)

        # Output node executes (with memory context + history including user msg)
        response = await self.output_node.process(command, self.history, self.ws, memory_context=mem_ctx)
        self.history.append({"role": "assistant", "content": response})

        # Memorizer updates shared state after each exchange
        await self.memorizer.update(self.history)

        # Sliding window — trim oldest messages, keep context in memorizer
        if len(self.history) > self.MAX_HISTORY:
            self.history = self.history[-self.MAX_HISTORY:]
|
|
|
|
|
|
# --- App ---

STATIC_DIR = Path(__file__).parent / "static"

app = FastAPI(title="cog")

# Keep a reference to the active runtime for API access.
# NOTE(review): single slot — a second WS connection replaces the first, so
# the /api/* endpoints always act on the most recently connected session.
_active_runtime: Runtime | None = None
|
|
|
|
|
|
@app.get("/health")
|
|
async def health():
|
|
return {"status": "ok"}
|
|
|
|
|
|
@app.get("/auth/config")
|
|
async def auth_config():
|
|
"""Public: auth config for frontend OIDC flow."""
|
|
return {
|
|
"enabled": AUTH_ENABLED,
|
|
"issuer": ZITADEL_ISSUER,
|
|
"clientId": ZITADEL_CLIENT_ID,
|
|
"projectId": ZITADEL_PROJECT_ID,
|
|
}
|
|
|
|
|
|
@app.websocket("/ws")
|
|
async def ws_endpoint(ws: WebSocket, token: str | None = Query(None), access_token: str | None = Query(None)):
|
|
global _active_runtime
|
|
# Validate auth if enabled
|
|
user_claims = {"sub": "anonymous"}
|
|
if AUTH_ENABLED and token:
|
|
try:
|
|
user_claims = await _validate_token(token)
|
|
# If id_token lacks name, enrich from userinfo with access_token
|
|
if not user_claims.get("name") and access_token:
|
|
async with httpx.AsyncClient() as client:
|
|
resp = await client.get(f"{ZITADEL_ISSUER}/oidc/v1/userinfo",
|
|
headers={"Authorization": f"Bearer {access_token}"})
|
|
if resp.status_code == 200:
|
|
info = resp.json()
|
|
log.info(f"[auth] userinfo enrichment: {info}")
|
|
user_claims["name"] = info.get("name")
|
|
user_claims["preferred_username"] = info.get("preferred_username")
|
|
user_claims["email"] = info.get("email")
|
|
except HTTPException:
|
|
await ws.close(code=4001, reason="Invalid token")
|
|
return
|
|
origin = ws.headers.get("origin", ws.headers.get("host", ""))
|
|
await ws.accept()
|
|
runtime = Runtime(ws, user_claims=user_claims, origin=origin)
|
|
_active_runtime = runtime
|
|
try:
|
|
while True:
|
|
data = await ws.receive_text()
|
|
msg = json.loads(data)
|
|
await runtime.handle_message(msg["text"])
|
|
except WebSocketDisconnect:
|
|
runtime.sensor.stop()
|
|
if _active_runtime is runtime:
|
|
_active_runtime = None
|
|
|
|
|
|
# --- API endpoints (for Claude to inspect runtime state) ---
|
|
|
|
import hashlib
|
|
from asyncio import Queue
|
|
from starlette.responses import StreamingResponse
|
|
|
|
# SSE subscribers (for titan/service accounts to watch live)
|
|
_sse_subscribers: list[Queue] = []
|
|
|
|
def _broadcast_sse(event: dict):
|
|
"""Push an event to all SSE subscribers."""
|
|
for q in _sse_subscribers:
|
|
try:
|
|
q.put_nowait(event)
|
|
except asyncio.QueueFull:
|
|
pass # drop if subscriber is too slow
|
|
|
|
def _state_hash() -> str:
|
|
"""Hash of current runtime state — cheap way to detect changes."""
|
|
if not _active_runtime:
|
|
return "no_session"
|
|
raw = json.dumps({
|
|
"mem": _active_runtime.memorizer.state,
|
|
"hlen": len(_active_runtime.history),
|
|
}, sort_keys=True)
|
|
return hashlib.md5(raw.encode()).hexdigest()[:12]
|
|
|
|
|
|
@app.get("/api/events")
|
|
async def sse_events(user=Depends(require_auth)):
|
|
"""SSE stream of runtime events (trace, state changes)."""
|
|
q: Queue = Queue(maxsize=100)
|
|
_sse_subscribers.append(q)
|
|
|
|
async def generate():
|
|
try:
|
|
while True:
|
|
event = await q.get()
|
|
yield f"data: {json.dumps(event)}\n\n"
|
|
except asyncio.CancelledError:
|
|
pass
|
|
finally:
|
|
_sse_subscribers.remove(q)
|
|
|
|
return StreamingResponse(generate(), media_type="text/event-stream",
|
|
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
|
|
|
|
|
|
@app.get("/api/poll")
|
|
async def poll(since: str = "", user=Depends(require_auth)):
|
|
"""Returns current hash. If 'since' matches, returns {changed: false}. Cheap polling."""
|
|
h = _state_hash()
|
|
if since and since == h:
|
|
return {"changed": False, "hash": h}
|
|
return {
|
|
"changed": True,
|
|
"hash": h,
|
|
"state": _active_runtime.memorizer.state if _active_runtime else None,
|
|
"history_len": len(_active_runtime.history) if _active_runtime else 0,
|
|
"last_messages": _active_runtime.history[-3:] if _active_runtime else [],
|
|
}
|
|
|
|
@app.post("/api/send")
|
|
async def api_send(body: dict, user=Depends(require_auth)):
|
|
"""Send a message as if the user typed it. Requires auth. Returns the response."""
|
|
if not _active_runtime:
|
|
raise HTTPException(status_code=409, detail="No active session — someone must be connected via WS first")
|
|
text = body.get("text", "").strip()
|
|
if not text:
|
|
raise HTTPException(status_code=400, detail="Missing 'text' field")
|
|
await _active_runtime.handle_message(text)
|
|
return {
|
|
"status": "ok",
|
|
"response": _active_runtime.history[-1]["content"] if _active_runtime.history else "",
|
|
"memorizer": _active_runtime.memorizer.state,
|
|
}
|
|
|
|
|
|
@app.post("/api/clear")
|
|
async def api_clear(user=Depends(require_auth)):
|
|
"""Clear conversation history."""
|
|
if not _active_runtime:
|
|
raise HTTPException(status_code=409, detail="No active session")
|
|
_active_runtime.history.clear()
|
|
return {"status": "cleared"}
|
|
|
|
|
|
@app.get("/api/state")
|
|
async def get_state(user=Depends(require_auth)):
|
|
"""Current memorizer state + history length."""
|
|
if not _active_runtime:
|
|
return {"status": "no_session"}
|
|
return {
|
|
"status": "active",
|
|
"memorizer": _active_runtime.memorizer.state,
|
|
"history_len": len(_active_runtime.history),
|
|
}
|
|
|
|
|
|
@app.get("/api/history")
|
|
async def get_history(last: int = 10, user=Depends(require_auth)):
|
|
"""Recent conversation history."""
|
|
if not _active_runtime:
|
|
return {"status": "no_session", "messages": []}
|
|
return {
|
|
"status": "active",
|
|
"messages": _active_runtime.history[-last:],
|
|
}
|
|
|
|
|
|
@app.get("/api/trace")
|
|
async def get_trace(last: int = 30, user=Depends(require_auth)):
|
|
"""Recent trace lines from trace.jsonl."""
|
|
if not TRACE_FILE.exists():
|
|
return {"lines": []}
|
|
lines = TRACE_FILE.read_text(encoding="utf-8").strip().split("\n")
|
|
parsed = []
|
|
for line in lines[-last:]:
|
|
try:
|
|
parsed.append(json.loads(line))
|
|
except json.JSONDecodeError:
|
|
pass
|
|
return {"lines": parsed}
|
|
|
|
|
|
# Serve index.html explicitly, then static assets
|
|
from fastapi.responses import FileResponse
|
|
|
|
@app.get("/")
|
|
async def index():
|
|
return FileResponse(STATIC_DIR / "index.html")
|
|
|
|
@app.get("/callback")
|
|
async def callback():
|
|
"""OIDC callback — serves the same SPA, JS handles the code exchange."""
|
|
return FileResponse(STATIC_DIR / "index.html")
|
|
|
|
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
|
|
|
|
|
|
if __name__ == "__main__":
    import uvicorn
    # Dev entry point: "agent:app" assumes this file is named agent.py;
    # reload=True restarts the server on source changes.
    uvicorn.run("agent:app", host="0.0.0.0", port=8000, reload=True)
|