"""Standalone MCP SSE app — proxies tool calls to assay-runtime."""

import asyncio
import hmac
import json
import logging
import os
from pathlib import Path

from dotenv import load_dotenv

# Load ../.env (relative to this file) before anything below reads os.environ.
load_dotenv(Path(__file__).parent.parent / ".env")

import httpx  # noqa: E402
from fastapi import Depends, FastAPI, HTTPException, Request, Response  # noqa: E402
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer  # noqa: E402
from mcp.server import Server  # noqa: E402
from mcp.server.sse import SseServerTransport  # noqa: E402
from mcp.types import TextContent, Tool  # noqa: E402

logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(name)s] %(message)s", datefmt="%H:%M:%S")
log = logging.getLogger("mcp-proxy")


# --- Config ---

def _parse_tokens(raw: str) -> list[str]:
    """Split a comma-separated token list, trimming whitespace and dropping empty entries."""
    return [token for token in (part.strip() for part in raw.split(",")) if token]


# Base URL of assay-runtime. A trailing slash is stripped so that
# f"{RUNTIME_URL}{path}" never produces a "//" path.
RUNTIME_URL = os.environ.get("RUNTIME_URL", "http://assay-runtime").rstrip("/")

_TOKEN_LIST = _parse_tokens(os.environ.get("SERVICE_TOKENS", ""))
# Bearer tokens accepted on the /mcp endpoints.
SERVICE_TOKENS = set(_TOKEN_LIST)
# Bearer token this proxy presents to assay-runtime: the first configured entry ("" if none).
SERVICE_TOKEN = _TOKEN_LIST[0] if _TOKEN_LIST else ""

if not SERVICE_TOKENS:
    log.warning("SERVICE_TOKENS is empty: every /mcp request will be rejected with 401")

# Seconds allowed for each individual HTTP request to the runtime.
_HTTP_TIMEOUT = 30
# Seconds between /api/result polls while waiting for an assay_send result.
_SEND_POLL_INTERVAL = 0.5
# Overall wall-clock budget (seconds) for one assay_send to produce a result.
_SEND_TIMEOUT = 30.0


app = FastAPI(title="assay-mcp")
_security = HTTPBearer()


async def require_auth(creds: HTTPAuthorizationCredentials = Depends(_security)) -> dict:
    """FastAPI dependency that admits only callers presenting a configured service token.

    Returns:
        A principal dict ``{"sub": "service", "source": "service_token"}``.

    Raises:
        HTTPException: 401 when the bearer token matches none of SERVICE_TOKENS.
    """
    presented = creds.credentials.encode()
    # Use a constant-time comparison for secrets rather than ordinary equality.
    # Tokens are compared as bytes so non-ASCII input cannot make compare_digest raise.
    if not any(hmac.compare_digest(presented, token.encode()) for token in SERVICE_TOKENS):
        raise HTTPException(status_code=401, detail="Invalid token")
    return {"sub": "service", "source": "service_token"}


@app.get("/health")
async def health():
    """Liveness probe; does not contact the runtime."""
    return {"status": "ok", "service": "mcp-proxy"}


# --- MCP Server ---

mcp_server = Server("assay")
_mcp_transport = SseServerTransport("/mcp/messages/")


def _parse_response(resp: httpx.Response) -> dict:
    """Convert a runtime HTTP response into either its JSON object or an ``{"error": ...}`` dict.

    Only status 200 counts as success. For error statuses, the ``detail`` field of a
    JSON body is used when present, otherwise the raw response text.
    """
    if resp.status_code == 200:
        try:
            data = resp.json()
        except ValueError:
            return {"error": f"Runtime returned invalid JSON: {resp.text[:200]}"}
        if not isinstance(data, dict):
            # Every caller treats the result as a mapping (.get / "error" in ...).
            return {"error": f"Runtime returned unexpected payload type: {type(data).__name__}"}
        return data
    try:
        payload = resp.json()
    except ValueError:
        return {"error": resp.text}
    detail = payload.get("detail", resp.text) if isinstance(payload, dict) else resp.text
    return {"error": detail}


async def _proxy_request(method: str, path: str, **kwargs) -> dict:
    """Send *method* to ``RUNTIME_URL + path`` with service auth.

    Never raises for transport problems. Failures come back as an ``{"error": ...}``
    dict so every tool can report them as text.
    """
    try:
        async with httpx.AsyncClient(timeout=_HTTP_TIMEOUT) as client:
            resp = await client.request(
                method,
                f"{RUNTIME_URL}{path}",
                headers={"Authorization": f"Bearer {SERVICE_TOKEN}"},
                **kwargs,
            )
    except Exception as e:  # boundary: connection refused, timeout, bad URL, ...
        log.warning("%s %s failed: %s", method, path, e)
        return {"error": f"Runtime unreachable: {e}"}
    return _parse_response(resp)


async def _proxy_get(path: str, params: dict | None = None) -> dict:
    """GET request to runtime (see _proxy_request for the result contract)."""
    return await _proxy_request("GET", path, params=params)


async def _proxy_post(path: str, body: dict | None = None) -> dict:
    """POST request to runtime with a JSON body; ``{}`` is sent when *body* is falsy."""
    return await _proxy_request("POST", path, json=body or {})


@mcp_server.list_tools()
async def list_tools():
    """Advertise the assay_* tools and their input schemas to MCP clients."""
    return [
        Tool(name="assay_send", description="Send a message to the cognitive agent and get a response.",
             inputSchema={"type": "object", "properties": {
                 "text": {"type": "string", "description": "Message text to send"},
                 "database": {"type": "string", "description": "Optional: database name for query_db context"},
             }, "required": ["text"]}),
        Tool(name="assay_trace", description="Get recent trace events from the pipeline (HUD events, tool calls, audit).",
             inputSchema={"type": "object", "properties": {
                 "last": {"type": "integer", "description": "Number of recent events (default 20)", "default": 20},
                 "filter": {"type": "string", "description": "Comma-separated event types to filter (e.g. 'tool_call,controls')"},
             }}),
        Tool(name="assay_history", description="Get recent chat messages from the active session.",
             inputSchema={"type": "object", "properties": {
                 "last": {"type": "integer", "description": "Number of recent messages (default 20)", "default": 20},
             }}),
        Tool(name="assay_state", description="Get the current memorizer state (mood, topic, language, facts).",
             inputSchema={"type": "object", "properties": {}}),
        Tool(name="assay_clear", description="Clear the active session (history, state, controls).",
             inputSchema={"type": "object", "properties": {}}),
        Tool(name="assay_graph", description="Get the active graph definition (nodes, edges, description).",
             inputSchema={"type": "object", "properties": {}}),
        Tool(name="assay_graph_list", description="List all available graph definitions.",
             inputSchema={"type": "object", "properties": {}}),
        Tool(name="assay_graph_switch", description="Switch the active graph for new sessions.",
             inputSchema={"type": "object", "properties": {
                 "name": {"type": "string", "description": "Graph name to switch to"},
             }, "required": ["name"]}),
    ]


# --- Tool handlers: each takes the tool's argument dict and returns MCP content ---

def _text(text: str) -> list[TextContent]:
    """Wrap *text* as a single MCP text content item."""
    return [TextContent(type="text", text=text)]


def _error(message) -> list[TextContent]:
    """Format *message* as the "ERROR: ..." text every tool uses for failures."""
    return _text(f"ERROR: {message}")


def _json_text(obj) -> list[TextContent]:
    """Pretty-print *obj* as JSON, keeping non-ASCII characters readable instead of escaping them."""
    return _text(json.dumps(obj, indent=2, ensure_ascii=False))


def _format_trace_event(event: dict) -> str:
    """Render one trace event as a fixed-width ``node event detail`` line."""
    # str() so null or non-string values from the runtime cannot break the ":12s" format spec.
    node = str(event.get("node", "?"))
    name = str(event.get("event", "?"))
    detail = event.get("detail", "")
    return f"{node:12s} {name:20s} {detail}".rstrip()


async def _tool_send(arguments: dict) -> list[TextContent]:
    """assay_send: check readiness, queue the message, then poll /api/result for the outcome."""
    text = arguments.get("text", "")
    if not text:
        return _text("ERROR: Missing 'text' argument.")

    # Step 1: check runtime is ready
    check = await _proxy_post("/api/send/check")
    if "error" in check:
        return _error(check["error"])
    if not check.get("ready"):
        return _error(f"{check.get('reason', 'unknown')}: {check.get('detail', '')}")

    # Step 2: queue message. The tool schema advertises an optional "database",
    # so forward it rather than silently dropping it.
    # NOTE(review): assumes /api/send accepts a "database" field — confirm against the runtime.
    body = {"text": text}
    if database := arguments.get("database"):
        body["database"] = database
    send = await _proxy_post("/api/send", body)
    if "error" in send:
        return _error(send["error"])
    msg_id = send.get("id", "")

    # Step 3: poll for the result against a wall-clock deadline. Counting iterations
    # would let slow polls stretch far past the budget. A single poll can still overrun
    # the deadline by up to _HTTP_TIMEOUT.
    # NOTE(review): /api/result is polled without msg_id, so this presumably relies on the
    # runtime tracking one in-flight message — confirm.
    loop = asyncio.get_running_loop()
    deadline = loop.time() + _SEND_TIMEOUT
    while loop.time() < deadline:
        await asyncio.sleep(_SEND_POLL_INTERVAL)
        result = await _proxy_get("/api/result")
        if "error" in result:
            return _error(result["error"])
        status = result.get("status", "")
        if status == "done":
            response = result.get("response")
            # TextContent requires a string; a null response falls back to the placeholder.
            return _text("[no response]" if response is None else str(response))
        if status == "error":
            return _error(result.get("detail", "pipeline failed"))
    log.warning("assay_send: no result for message %r within %ss", msg_id, _SEND_TIMEOUT)
    return _text(f"ERROR: Pipeline timeout ({_SEND_TIMEOUT:g}s)")


async def _tool_trace(arguments: dict) -> list[TextContent]:
    """assay_trace: fetch recent trace events and format them one per line."""
    params = {"last": arguments.get("last", 20)}
    if event_filter := arguments.get("filter", ""):
        params["filter"] = event_filter
    result = await _proxy_get("/api/trace", params)
    if "error" in result:
        return _error(result["error"])
    lines = [_format_trace_event(e) for e in result.get("lines") or []]
    return _text("\n".join(lines) if lines else "(no events)")


async def _tool_history(arguments: dict) -> list[TextContent]:
    """assay_history: return the session's recent messages as JSON."""
    result = await _proxy_get("/api/history", {"last": arguments.get("last", 20)})
    if "error" in result:
        return _error(result["error"])
    return _json_text(result.get("messages", []))


async def _tool_state(arguments: dict) -> list[TextContent]:
    """assay_state: return the runtime's /api/state payload as JSON."""
    result = await _proxy_get("/api/state")
    if "error" in result:
        return _error(result["error"])
    return _json_text(result)


async def _tool_clear(arguments: dict) -> list[TextContent]:
    """assay_clear: POST /api/clear and confirm."""
    result = await _proxy_post("/api/clear")
    if "error" in result:
        return _error(result["error"])
    return _text("Session cleared.")


async def _tool_graph(arguments: dict) -> list[TextContent]:
    """assay_graph: return the active graph definition as JSON."""
    result = await _proxy_get("/api/graph/active")
    if "error" in result:
        return _error(result["error"])
    return _json_text(result)


async def _tool_graph_list(arguments: dict) -> list[TextContent]:
    """assay_graph_list: return the available graph definitions as JSON."""
    result = await _proxy_get("/api/graph/list")
    if "error" in result:
        return _error(result["error"])
    return _json_text(result.get("graphs", []))


async def _tool_graph_switch(arguments: dict) -> list[TextContent]:
    """assay_graph_switch: ask the runtime to switch the active graph by name."""
    gname = arguments.get("name", "")
    if not gname:
        return _text("ERROR: Missing 'name' argument.")
    result = await _proxy_post("/api/graph/switch", {"name": gname})
    if "error" in result:
        return _error(result["error"])
    return _text(f"Switched to graph '{result.get('name', gname)}'. New sessions will use this graph.")


_TOOL_HANDLERS = {
    "assay_send": _tool_send,
    "assay_trace": _tool_trace,
    "assay_history": _tool_history,
    "assay_state": _tool_state,
    "assay_clear": _tool_clear,
    "assay_graph": _tool_graph,
    "assay_graph_list": _tool_graph_list,
    "assay_graph_switch": _tool_graph_switch,
}


@mcp_server.call_tool()
async def call_tool(name: str, arguments: dict):
    """Dispatch an MCP tool call to its handler; unknown names get a text reply, not an exception."""
    handler = _TOOL_HANDLERS.get(name)
    if handler is None:
        return _text(f"Unknown tool: {name}")
    # Guard against clients (or SDK versions) that pass arguments=None.
    return await handler(arguments or {})


# --- Mount MCP SSE endpoints ---

class _AlreadySentResponse(Response):
    """Endpoint return value for routes whose MCP transport already wrote the HTTP response."""

    async def __call__(self, scope, receive, send) -> None:
        # Deliberately send nothing. The transport already completed the response on the
        # raw ASGI channel, and if FastAPI sent its own response for a returned None, the
        # ASGI server would reject it with "response already completed".
        return None


@app.get("/mcp/sse")
async def mcp_sse(request: Request, user=Depends(require_auth)):
    """Open an authenticated MCP SSE stream and run the MCP server over it."""
    # The transport writes to the raw ASGI send callable (request._send) itself.
    async with _mcp_transport.connect_sse(request.scope, request.receive, request._send) as streams:
        await mcp_server.run(streams[0], streams[1], mcp_server.create_initialization_options())
    return _AlreadySentResponse()


@app.post("/mcp/messages/")
async def mcp_messages(request: Request, user=Depends(require_auth)):
    """Forward an authenticated client->server MCP message to the SSE transport."""
    # handle_post_message sends its own HTTP reply on request._send.
    await _mcp_transport.handle_post_message(request.scope, request.receive, request._send)
    return _AlreadySentResponse()