"""MAA Gate: governance integration; MPC check and tool classification. Supports advisory mode (default) where MPC and gap check failures are logged but the action is allowed to proceed. """ from typing import Any from fusionagi._logger import logger from fusionagi.maa.gap_detection import GapReport, check_gaps from fusionagi.maa.layers.dlt_engine import DLTEngine from fusionagi.maa.layers.mpc_authority import MPCAuthority from fusionagi.schemas.audit import GovernanceMode # Default manufacturing tool names that require MPC DEFAULT_MANUFACTURING_TOOLS = frozenset({"cnc_emit", "am_slice", "machine_bind"}) class MAAGate: """ Gate for manufacturing tools: (tool_name, args) -> (allowed, sanitized_args | error_message). Compatible with Guardrails.add_check. Manufacturing tools require valid MPC and no gaps. """ def __init__( self, mpc_authority: MPCAuthority, dlt_engine: DLTEngine | None = None, manufacturing_tools: set[str] | frozenset[str] | None = None, mode: GovernanceMode = GovernanceMode.ADVISORY, ) -> None: self._mpc = mpc_authority self._dlt = dlt_engine or DLTEngine() self._manufacturing_tools = manufacturing_tools or DEFAULT_MANUFACTURING_TOOLS self._mode = mode def is_manufacturing(self, tool_name: str, tool_def: Any = None) -> bool: """Return True if tool is classified as manufacturing (allowlist or ToolDef scope).""" if tool_def is not None and getattr(tool_def, "manufacturing", False): return True return tool_name in self._manufacturing_tools def check(self, tool_name: str, args: dict[str, Any]) -> tuple[bool, dict[str, Any] | str]: """ Pre-check for Guardrails: (tool_name, args) -> (allowed, sanitized_args or error_message). Non-manufacturing tools pass through. Manufacturing tools require mpc_id, valid MPC, no gaps. """ if not self.is_manufacturing(tool_name, None): logger.debug("MAA check pass-through (non-manufacturing)", extra={"tool_name": tool_name}) return True, args mpc_id_value = args.get("mpc_id") or args.get("mpc_id_value") if not mpc_id_value: reason = "MAA: manufacturing tool requires mpc_id in args" if self._mode == GovernanceMode.ADVISORY: logger.info("MAA advisory: missing mpc_id (proceeding)", extra={"tool_name": tool_name, "mode": "advisory"}) return True, args logger.info("MAA check denied", extra={"tool_name": tool_name, "reason": "missing mpc_id"}) return False, reason cert = self._mpc.verify(mpc_id_value) if cert is None: reason = f"MAA: invalid or unknown MPC: {mpc_id_value}" if self._mode == GovernanceMode.ADVISORY: logger.info("MAA advisory: invalid MPC (proceeding)", extra={"tool_name": tool_name, "mpc_id": mpc_id_value, "mode": "advisory"}) return True, args logger.info("MAA check denied", extra={"tool_name": tool_name, "reason": "invalid or unknown MPC"}) return False, reason context: dict[str, Any] = { **args, "mpc_id": mpc_id_value, "mpc_version": cert.mpc_id.version, } gaps = check_gaps(context) if gaps: root_cause = _format_root_cause(gaps) if self._mode == GovernanceMode.ADVISORY: logger.info("MAA advisory: gaps detected (proceeding)", extra={"tool_name": tool_name, "gap_count": len(gaps), "mode": "advisory"}) return True, args logger.info("MAA check denied", extra={"tool_name": tool_name, "reason": "gaps", "gap_count": len(gaps)}) return False, root_cause dlt_contract_id = args.get("dlt_contract_id") if dlt_contract_id: dlt_context = args.get("dlt_context") or context ok, cause = self._dlt.evaluate(dlt_contract_id, dlt_context) if not ok: if self._mode == GovernanceMode.ADVISORY: logger.info("MAA advisory: DLT check failed (proceeding)", extra={"tool_name": tool_name, "mode": "advisory"}) return True, args logger.info("MAA check denied", extra={"tool_name": tool_name, "reason": "dlt_failed"}) return False, f"MAA DLT: {cause}" logger.debug("MAA check allowed", extra={"tool_name": tool_name}) return True, args def _format_root_cause(gaps: list[GapReport]) -> str: """Format gap reports as single root-cause message.""" parts = [f"MAA gap: {g.gap_class.value} — {g.description}" for g in gaps] if any(g.required_resolution for g in gaps): parts.append("Required resolution: " + "; ".join(g.required_resolution for g in gaps if g.required_resolution)) return " | ".join(parts)