"""Shared test harness for node-level tests."""

import asyncio
import json
import os
import sys
import time
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

import httpx

# Add parent to path so we can import agent
sys.path.insert(0, str(Path(__file__).parent.parent))

from agent.types import Envelope, Command, InputAnalysis, ThoughtResult


class HudCapture:
    """Mock send_hud that captures all HUD events for inspection."""

    def __init__(self):
        # Every payload passed to the mock, in call order.
        self.events: list[dict] = []

    async def __call__(self, data: dict):
        """Record one HUD payload."""
        self.events.append(data)

    def find(self, event: str) -> list[dict]:
        """Return all captured payloads whose "event" field equals *event*."""
        return [e for e in self.events if e.get("event") == event]

    def has(self, event: str) -> bool:
        """Return True if any captured payload has "event" equal to *event*."""
        return any(e.get("event") == event for e in self.events)

    def last(self) -> dict:
        """Return the most recent payload, or an empty dict if none was captured."""
        return self.events[-1] if self.events else {}

    def clear(self):
        """Forget all captured payloads."""
        self.events.clear()


class MockWebSocket:
    """Mock WebSocket that captures sent messages."""

    def __init__(self):
        # Raw text frames in the order they were sent.
        self.sent: list[str] = []
        # NOTE(review): presumably mirrors the "open" ready state that the
        # code under test checks for -- confirm against the callers.
        self.readyState = 1

    async def send_text(self, text: str):
        """Record one outgoing text frame."""
        self.sent.append(text)

    def get_messages(self) -> list[dict]:
        """Return every sent frame decoded from JSON.

        Raises:
            json.JSONDecodeError: if any sent frame is not valid JSON.
        """
        return [json.loads(s) for s in self.sent]

    def get_deltas(self) -> str:
        """Reconstruct streamed text from delta messages.

        A frame counts as a delta when it decodes to a JSON object whose
        top-level "type" is "delta". Checking the parsed value (rather than
        searching the raw text) makes the result independent of how the JSON
        was serialized and ignores nested objects that merely contain a
        "type": "delta" pair. Frames that are not valid JSON are skipped.
        """
        pieces = []
        for raw in self.sent:
            try:
                msg = json.loads(raw)
            except ValueError:
                # Not JSON, so it cannot be a delta message.
                continue
            if isinstance(msg, dict) and msg.get("type") == "delta":
                # A missing or null "content" contributes nothing.
                pieces.append(msg.get("content") or "")
        return "".join(pieces)


def make_envelope(text: str, user_id: str = "bob") -> Envelope:
    """Build an Envelope for *text* with session "test" and the current local time."""
    return Envelope(
        text=text,
        user_id=user_id,
        session_id="test",
        timestamp=time.strftime("%Y-%m-%d %H:%M:%S"),
    )


def make_command(intent: str = "request", topic: str = "", text: str = "",
                 complexity: str = "simple", tone: str = "casual",
                 language: str = "en", who: str = "bob") -> Command:
    """Build a Command wrapping an InputAnalysis made from the given fields.

    ``source_text`` is *text* when it is non-empty, otherwise *topic*.
    """
    return Command(
        analysis=InputAnalysis(
            who=who,
            language=language,
            intent=intent,
            topic=topic,
            tone=tone,
            complexity=complexity,
        ),
        source_text=text or topic,
    )


def make_history(messages: Optional[list[tuple[str, str]]] = None) -> list[dict]:
    """Create history from (role, content) tuples.

    Returns an empty list when *messages* is None or empty.
    """
    if not messages:
        return []
    return [{"role": r, "content": c} for r, c in messages]


@dataclass
class NodeTestResult:
    """Outcome of a single test run by NodeTestRunner."""

    name: str  # full test name, including the suite prefix when one is active
    passed: bool
    detail: str = ""  # failure or error message; empty for a passing test
    elapsed_ms: int = 0


def run_async(coro):
    """Run an async function synchronously.

    Reuses the thread's current event loop when one is set and still open,
    so consecutive calls share a loop. When no usable loop exists (for
    example after an earlier ``asyncio.run()`` cleared it, or on Python
    versions where ``asyncio.get_event_loop()`` no longer creates one), a
    fresh loop is created and installed as the current loop.

    Returns whatever *coro* returns; exceptions raised by it propagate.
    """
    with warnings.catch_warnings():
        # get_event_loop() may emit a DeprecationWarning when it has to
        # create a loop implicitly; that case is handled explicitly below.
        warnings.simplefilter("ignore", DeprecationWarning)
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = None
    if loop is None or loop.is_closed():
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop.run_until_complete(coro)


class NodeTestRunner:
    """Collects and runs node-level tests. Optionally reports to frontend."""

    def __init__(self, report_url: Optional[str] = None, token: Optional[str] = None):
        """Create a runner.

        Args:
            report_url: Base URL to report to; falls back to the
                COG_TEST_URL environment variable. Reporting is disabled
                when neither is set.
            token: Bearer token for reports; falls back to COG_TEST_TOKEN,
                then to an empty string.
        """
        self.results: list[NodeTestResult] = []
        self.report_url = report_url or os.environ.get("COG_TEST_URL")
        self.token = token or os.environ.get("COG_TEST_TOKEN", "")
        self._suite = ""

    def _report(self, event: str, **data):
        """POST test status to the frontend on a best-effort basis.

        Does nothing when no report URL is configured. The request is
        synchronous, so it can delay the caller until httpx's timeout
        (configured as 3) expires, but any failure is swallowed so that
        reporting problems never fail a test.
        """
        if not self.report_url:
            return
        try:
            httpx.post(
                f"{self.report_url}/api/test/status",
                json={"event": event, **data},
                headers={"Authorization": f"Bearer {self.token}"},
                timeout=3,
            )
        except Exception:
            # Deliberate: reporting is optional and must not affect results.
            pass

    def start_suite(self, name: str, count: int = 0):
        """Call before a group of tests."""
        self._suite = name
        self._report("suite_start", suite=name, count=count)

    def end_suite(self):
        """Call after a group of tests."""
        self._report("suite_end")
        self._suite = ""

    def _record(self, name: str, full_name: str, started: float, label: str,
                error: Optional[BaseException] = None, detail: str = "",
                report_detail: str = ""):
        """Store, print and report the outcome of one test.

        *error* is None for a pass. *detail* goes into the stored result,
        *report_detail* (already truncated by the caller) into the report.
        """
        elapsed = int((time.perf_counter() - started) * 1000)
        passed = error is None
        self.results.append(NodeTestResult(
            name=full_name, passed=passed, detail=detail, elapsed_ms=elapsed))
        print(f" {label} {name} ({elapsed}ms)")
        result = {
            "step": full_name,
            "check": name,
            "status": "PASS" if passed else "FAIL",
        }
        if not passed:
            print(f" {error}")
            result["detail"] = report_detail
        result["elapsed_ms"] = elapsed
        self._report("step_result", result=result)

    def test(self, name: str, coro):
        """Run a single async test, catch and record result.

        An AssertionError is recorded as a failure. Any other Exception is
        recorded as a failure whose detail is prefixed with "ERROR: ".
        Neither is re-raised.
        """
        full_name = f"{self._suite}: {name}" if self._suite else name
        # perf_counter is monotonic, so the elapsed time cannot be skewed
        # by wall-clock adjustments.
        started = time.perf_counter()
        try:
            run_async(coro)
        except AssertionError as e:
            self._record(name, full_name, started, "FAIL", error=e,
                         detail=str(e), report_detail=str(e)[:200])
        except Exception as e:
            self._record(name, full_name, started, "ERR", error=e,
                         detail=f"ERROR: {e}",
                         report_detail=f"ERROR: {str(e)[:200]}")
        else:
            self._record(name, full_name, started, "OK")

    def summary(self) -> tuple[int, int]:
        """Return (passed, failed) counts over all recorded results."""
        passed = sum(1 for r in self.results if r.passed)
        failed = len(self.results) - passed
        return passed, failed