"""OpenAI LLM adapter with error handling and retry logic.""" import time from typing import Any from fusionagi._logger import logger from fusionagi.adapters.base import LLMAdapter class OpenAIAdapterError(Exception): """Base exception for OpenAI adapter errors.""" pass class OpenAIRateLimitError(OpenAIAdapterError): """Raised when rate limited by OpenAI API.""" pass class OpenAIAuthenticationError(OpenAIAdapterError): """Raised when authentication fails.""" pass class OpenAIAdapter(LLMAdapter): """ OpenAI API adapter with retry logic and error handling. Requires openai package and OPENAI_API_KEY. Features: - Automatic retry with exponential backoff for transient errors - Proper error classification (rate limits, auth errors, etc.) - Structured output support via complete_structured() """ def __init__( self, model: str = "gpt-4o-mini", api_key: str | None = None, max_retries: int = 3, retry_delay: float = 1.0, retry_multiplier: float = 2.0, max_retry_delay: float = 30.0, **client_kwargs: Any, ) -> None: """ Initialize the OpenAI adapter. Args: model: Default model to use (e.g., "gpt-4o-mini", "gpt-4o"). api_key: OpenAI API key. If None, uses OPENAI_API_KEY env var. max_retries: Maximum number of retry attempts for transient errors. retry_delay: Initial delay between retries in seconds. retry_multiplier: Multiplier for exponential backoff. max_retry_delay: Maximum delay between retries. **client_kwargs: Additional arguments passed to OpenAI client. """ self._model = model self._api_key = api_key self._max_retries = max_retries self._retry_delay = retry_delay self._retry_multiplier = retry_multiplier self._max_retry_delay = max_retry_delay self._client_kwargs = client_kwargs self._client: Any = None self._openai_module: Any = None def _get_client(self) -> Any: if self._client is None: try: import openai self._openai_module = openai self._client = openai.OpenAI(api_key=self._api_key, **self._client_kwargs) except ImportError as e: raise ImportError("Install with: pip install fusionagi[openai]") from e return self._client def _is_retryable_error(self, error: Exception) -> bool: """Check if an error is retryable (transient).""" if self._openai_module is None: return False # Rate limit errors are retryable if hasattr(self._openai_module, "RateLimitError"): if isinstance(error, self._openai_module.RateLimitError): return True # API connection errors are retryable if hasattr(self._openai_module, "APIConnectionError"): if isinstance(error, self._openai_module.APIConnectionError): return True # Internal server errors are retryable if hasattr(self._openai_module, "InternalServerError"): if isinstance(error, self._openai_module.InternalServerError): return True # Timeout errors are retryable if hasattr(self._openai_module, "APITimeoutError"): if isinstance(error, self._openai_module.APITimeoutError): return True return False def _classify_error(self, error: Exception) -> Exception: """Convert OpenAI exceptions to adapter exceptions.""" if self._openai_module is None: return OpenAIAdapterError(str(error)) if hasattr(self._openai_module, "RateLimitError"): if isinstance(error, self._openai_module.RateLimitError): return OpenAIRateLimitError(str(error)) if hasattr(self._openai_module, "AuthenticationError"): if isinstance(error, self._openai_module.AuthenticationError): return OpenAIAuthenticationError(str(error)) return OpenAIAdapterError(str(error)) def complete( self, messages: list[dict[str, str]], **kwargs: Any, ) -> str: """ Call OpenAI chat completion with retry logic. Args: messages: List of message dicts with 'role' and 'content'. **kwargs: Additional arguments for the API call (e.g., temperature). Returns: The assistant's response content. Raises: OpenAIAuthenticationError: If authentication fails. OpenAIRateLimitError: If rate limited after all retries. OpenAIAdapterError: For other API errors after all retries. """ # Validate messages format if not messages: logger.warning("OpenAI complete called with empty messages") return "" for i, msg in enumerate(messages): if not isinstance(msg, dict): raise ValueError(f"Message {i} must be a dict, got {type(msg).__name__}") if "role" not in msg: raise ValueError(f"Message {i} missing 'role' key") if "content" not in msg: raise ValueError(f"Message {i} missing 'content' key") client = self._get_client() model = kwargs.get("model", self._model) call_kwargs = {**kwargs, "model": model} last_error: Exception | None = None delay = self._retry_delay for attempt in range(self._max_retries + 1): try: resp = client.chat.completions.create( messages=messages, **call_kwargs, ) choice = resp.choices[0] if resp.choices else None if choice and choice.message and choice.message.content: return str(choice.message.content) logger.debug("OpenAI empty response", extra={"model": model, "attempt": attempt}) return "" except Exception as e: last_error = e # Don't retry authentication errors if self._openai_module and hasattr(self._openai_module, "AuthenticationError"): if isinstance(e, self._openai_module.AuthenticationError): logger.error("OpenAI authentication failed", extra={"error": str(e)}) raise OpenAIAuthenticationError(str(e)) from e # Check if retryable if not self._is_retryable_error(e): logger.error( "OpenAI non-retryable error", extra={"error": str(e), "error_type": type(e).__name__}, ) raise self._classify_error(e) from e # Log retry attempt if attempt < self._max_retries: logger.warning( "OpenAI retryable error, retrying", extra={ "error": str(e), "attempt": attempt + 1, "max_retries": self._max_retries, "delay": delay, }, ) time.sleep(delay) delay = min(delay * self._retry_multiplier, self._max_retry_delay) # All retries exhausted logger.error( "OpenAI all retries exhausted", extra={"error": str(last_error), "attempts": self._max_retries + 1}, ) if last_error is not None: raise self._classify_error(last_error) from last_error raise OpenAIAdapterError("All retries exhausted with unknown error") async def acomplete( self, messages: list[dict[str, str]], **kwargs: Any, ) -> str: """Async version of complete using OpenAI's async client. Args: messages: List of message dicts with 'role' and 'content'. **kwargs: Additional arguments for the API call. Returns: The assistant's response content. """ import asyncio if not messages: return "" try: import openai except ImportError as e: raise ImportError("Install with: pip install fusionagi[openai]") from e async_client = openai.AsyncOpenAI(api_key=self._api_key, **self._client_kwargs) model = kwargs.pop("model", self._model) last_error: Exception | None = None delay = self._retry_delay for attempt in range(self._max_retries + 1): try: response = await async_client.chat.completions.create( model=model, messages=messages, **kwargs # type: ignore[arg-type] ) content = response.choices[0].message.content or "" return content except Exception as e: last_error = e if not self._is_retryable_error(e) or attempt == self._max_retries: break logger.warning( "OpenAI async retry", extra={"attempt": attempt + 1, "error": str(e), "delay": delay}, ) await asyncio.sleep(delay) delay = min(delay * self._retry_multiplier, self._max_retry_delay) if last_error is not None: raise self._classify_error(last_error) from last_error raise OpenAIAdapterError("All retries exhausted") def complete_structured( self, messages: list[dict[str, str]], schema: dict[str, Any] | None = None, **kwargs: Any, ) -> Any: """ Call OpenAI with JSON mode for structured output. Args: messages: List of message dicts with 'role' and 'content'. schema: Optional JSON schema for response validation (informational). **kwargs: Additional arguments for the API call. Returns: Parsed JSON response or None if parsing fails. """ import json # Enable JSON mode call_kwargs = {**kwargs, "response_format": {"type": "json_object"}} # Add schema hint to system message if provided if schema and messages: schema_hint = f"\n\nRespond with JSON matching this schema: {json.dumps(schema)}" if messages[0].get("role") == "system": messages = [ {**messages[0], "content": messages[0]["content"] + schema_hint}, *messages[1:], ] else: messages = [ {"role": "system", "content": f"You must respond with valid JSON.{schema_hint}"}, *messages, ] raw = self.complete(messages, **call_kwargs) if not raw: return None try: return json.loads(raw) except json.JSONDecodeError as e: logger.warning( "OpenAI JSON parse failed", extra={"error": str(e), "raw_response": raw[:200]}, ) return None