"""TensorFlow adapter: local model inference via TF/Keras with TensorCore. Requires: pip install fusionagi[gpu] Provides LLMAdapter-compatible interface for locally-hosted TensorFlow/Keras models. Supports TensorCore mixed-precision, XLA compilation, and GPU memory management. """ from __future__ import annotations import json from typing import Any from fusionagi._logger import logger from fusionagi.adapters.base import LLMAdapter try: import numpy as np import tensorflow as tf except ImportError as e: raise ImportError( "TensorFlow is required for TensorFlowAdapter. " "Install with: pip install fusionagi[gpu]" ) from e class TensorFlowAdapter(LLMAdapter): """LLM adapter for local TensorFlow/Keras model inference. Loads a saved Keras model or TF SavedModel and runs inference with TensorCore acceleration when available. Args: model_path: Path to a saved Keras model (.keras) or SavedModel directory. tokenizer: Optional tokenizer callable (text -> token IDs). max_length: Maximum sequence length for generation. temperature: Sampling temperature. mixed_precision: Enable FP16 mixed-precision for TensorCore. """ def __init__( self, model_path: str | None = None, model: Any | None = None, tokenizer: Any | None = None, max_length: int = 512, temperature: float = 0.7, mixed_precision: bool = False, ) -> None: self._model: Any = None self._tokenizer = tokenizer self._max_length = max_length self._temperature = temperature self._model_path = model_path if mixed_precision: try: tf.keras.mixed_precision.set_global_policy("mixed_float16") logger.info("TensorFlowAdapter: TensorCore mixed-precision enabled") except Exception: logger.warning("TensorFlowAdapter: mixed-precision not available") if model is not None: self._model = model logger.info("TensorFlowAdapter initialized with provided model") elif model_path: self._load_model(model_path) else: logger.info( "TensorFlowAdapter initialized without model " "(will use embedding-based synthesis)" ) def _load_model(self, path: str) -> None: """Load a TF SavedModel or Keras model from disk.""" try: self._model = tf.saved_model.load(path) logger.info("TensorFlowAdapter: loaded SavedModel", extra={"path": path}) except Exception: try: self._model = tf.keras.models.load_model(path) logger.info("TensorFlowAdapter: loaded Keras model", extra={"path": path}) except Exception: logger.warning( "TensorFlowAdapter: no model loaded; " "falling back to embedding synthesis", extra={"path": path}, ) def complete( self, messages: list[dict[str, str]], **kwargs: Any, ) -> str: """Generate completion using the loaded TF model. If no model is loaded, falls back to embedding-based synthesis that uses GPU-accelerated similarity scoring. Args: messages: List of message dicts with 'role' and 'content'. **kwargs: Additional parameters (temperature, max_length). Returns: Generated response text. """ if self._model is not None and self._tokenizer is not None: return self._model_inference(messages, **kwargs) return self._embedding_synthesis(messages) def complete_structured( self, messages: list[dict[str, str]], schema: dict[str, Any] | None = None, **kwargs: Any, ) -> Any: """Attempt structured JSON output from the model. Falls back to parsing the raw completion if the model doesn't natively support structured output. """ raw = self.complete(messages, **kwargs) try: return json.loads(raw) except (json.JSONDecodeError, TypeError): return None def _model_inference( self, messages: list[dict[str, str]], **kwargs: Any, ) -> str: """Run inference through the loaded TF/Keras model.""" prompt = self._messages_to_prompt(messages) temperature = kwargs.get("temperature", self._temperature) max_length = kwargs.get("max_length", self._max_length) tokenizer = self._tokenizer assert tokenizer is not None tokens = tokenizer(prompt) if isinstance(tokens, (list, np.ndarray)): input_tensor = tf.constant([tokens[:max_length]], dtype=tf.int32) else: input_tensor = tokens try: if hasattr(self._model, "generate"): output = self._model.generate( input_tensor, max_length=max_length, temperature=temperature, ) elif hasattr(self._model, "predict"): output = self._model.predict(input_tensor) elif callable(self._model): output = self._model(input_tensor) else: logger.warning("TensorFlowAdapter: model has no callable interface") return self._embedding_synthesis(messages) if isinstance(output, tf.Tensor): output = output.numpy() if hasattr(output, "tolist"): output = output.tolist() if isinstance(output, list) and output: if isinstance(output[0], list): output = output[0] if isinstance(output[0], (int, float)): if tokenizer and hasattr(tokenizer, "decode"): return str(tokenizer.decode(output)) return str(output) # type: ignore[no-any-return] except Exception as e: logger.warning( "TensorFlowAdapter: model inference failed, using synthesis", extra={"error": str(e)}, ) return self._embedding_synthesis(messages) def _embedding_synthesis(self, messages: list[dict[str, str]]) -> str: """Fallback: synthesize response using GPU-accelerated embeddings. Embeds message content and produces a summary based on semantic similarity between parts. """ content_parts: list[str] = [] for msg in messages: content = msg.get("content", "") if isinstance(content, str) and content.strip(): content_parts.append(content.strip()) if not content_parts: return "" from fusionagi.gpu.backend import get_backend be = get_backend() embeddings = be.embed_texts(content_parts) emb_np = be.to_numpy(embeddings) mean_emb = np.mean(emb_np, axis=0, keepdims=True) sims = be.to_numpy( be.cosine_similarity_matrix(be.from_numpy(mean_emb), embeddings) )[0] ranked_indices = np.argsort(sims)[::-1] summary_parts: list[str] = [] for idx in ranked_indices[:5]: part = content_parts[idx] summary_parts.append(part[:300]) return "\n\n".join(summary_parts) @staticmethod def _messages_to_prompt(messages: list[dict[str, str]]) -> str: """Convert message list to a flat prompt string.""" parts: list[str] = [] for msg in messages: role = msg.get("role", "user") content = msg.get("content", "") parts.append(f"<|{role}|>\n{content}") return "\n".join(parts) def device_summary(self) -> dict[str, Any]: """Return device and model information.""" gpus = tf.config.list_physical_devices("GPU") return { "adapter": "tensorflow", "model_path": self._model_path, "has_model": self._model is not None, "has_tokenizer": self._tokenizer is not None, "gpu_count": len(gpus), "tf_version": tf.__version__, }