"""Async HTTP client for Anthropic-compatible AI API.""" import json import time import httpx from dataclasses import dataclass, field from typing import Optional @dataclass class StreamingMetrics: """Metrics collected during streaming response.""" ttft_ms: float = 0.0 token_timestamps: list = field(default_factory=list) total_tokens: int = 0 tps: float = 0.0 class AIClient: """Async client for Anthropic-compatible AI API.""" def __init__(self, base_url: str, api_key: str, model: str, timeout: float = 60, anthropic_version: str = "2023-06-01"): self.base_url = base_url.rstrip('/') self.api_key = api_key self.model = model self.timeout = timeout self.anthropic_version = anthropic_version self._client: Optional[httpx.AsyncClient] = None async def __aenter__(self): self._client = httpx.AsyncClient( timeout=httpx.Timeout(self.timeout, connect=10.0), http2=True, follow_redirects=True, ) return self async def __aexit__(self, *args): if self._client: await self._client.aclose() self._client = None def _get_headers(self) -> dict: return { "x-api-key": self.api_key, "anthropic-version": self.anthropic_version, "content-type": "application/json", } def _get_url(self) -> str: return f"{self.base_url}/v1/messages?beta=true" def _build_body(self, prompt: str, max_tokens: int = 1024, system: str = None, temperature: float = None) -> dict: body = { "model": self.model, "max_tokens": max_tokens, "messages": [{"role": "user", "content": prompt}], } if system: body["system"] = system if temperature is not None: body["temperature"] = temperature return body async def send_message(self, prompt: str, max_tokens: int = 1024, system: str = None, temperature: float = None ) -> tuple: """ Send a non-streaming message. Returns: (response_text, latency_ms, response_headers) """ if not self._client: raise RuntimeError("Client not initialized. Use 'async with' context.") body = self._build_body(prompt, max_tokens, system, temperature) start = time.perf_counter() response = await self._client.post( self._get_url(), headers=self._get_headers(), json=body, ) latency_ms = (time.perf_counter() - start) * 1000 response.raise_for_status() data = response.json() # Extract text from response text = "" if "content" in data and len(data["content"]) > 0: text = data["content"][0].get("text", "") # Collect headers headers = dict(response.headers) return text, latency_ms, headers async def send_message_streaming(self, prompt: str, max_tokens: int = 1024, system: str = None, temperature: float = None ) -> tuple: """ Send a streaming message using SSE. Returns: (full_text, streaming_metrics, response_headers) """ if not self._client: raise RuntimeError("Client not initialized. Use 'async with' context.") body = self._build_body(prompt, max_tokens, system, temperature) body["stream"] = True metrics = StreamingMetrics() full_text = "" response_headers = {} start = time.perf_counter() first_token_received = False async with self._client.stream( "POST", self._get_url(), headers=self._get_headers(), json=body, ) as response: response.raise_for_status() response_headers = dict(response.headers) buffer = "" async for chunk in response.aiter_text(): buffer += chunk while "\n" in buffer: line, buffer = buffer.split("\n", 1) line = line.strip() if not line or line.startswith(":"): continue if line.startswith("data: "): data_str = line[6:] if data_str.strip() == "[DONE]": continue try: event_data = json.loads(data_str) except (json.JSONDecodeError, ValueError): continue event_type = event_data.get("type", "") if event_type == "content_block_delta": delta = event_data.get("delta", {}) text_chunk = delta.get("text", "") if text_chunk: now = time.perf_counter() if not first_token_received: metrics.ttft_ms = (now - start) * 1000 first_token_received = True metrics.token_timestamps.append(now - start) metrics.total_tokens += 1 full_text += text_chunk elapsed = time.perf_counter() - start if metrics.total_tokens > 0 and elapsed > 0: if len(metrics.token_timestamps) > 1: generation_time = metrics.token_timestamps[-1] - metrics.token_timestamps[0] if generation_time > 0: metrics.tps = (metrics.total_tokens - 1) / generation_time else: metrics.tps = metrics.total_tokens / elapsed else: metrics.tps = metrics.total_tokens / elapsed return full_text, metrics, response_headers