"""Guardrails: pre/post checks for tool calls (block paths, sanitize inputs)."""

import posixpath
import re
from collections.abc import Callable
from typing import Any

from pydantic import BaseModel, Field

from fusionagi._logger import logger

# Signature of a custom pre-check registered via Guardrails.add_check:
# (tool_name, args) -> (allowed, sanitized_args | error_message | None).
PreCheckFn = Callable[[str, dict[str, Any]], tuple[bool, dict[str, Any] | str | None]]


class PreCheckResult(BaseModel):
    """Result of a guardrails pre-check: allowed, optional sanitized args, optional error message."""

    allowed: bool = Field(..., description="Whether the call is allowed")
    sanitized_args: dict[str, Any] | None = Field(default=None, description="Args to use if allowed and sanitized")
    error_message: str | None = Field(default=None, description="Reason for denial if not allowed")


class Guardrails:
    """Pre/post checks for tool invocations.

    Path rules (prefixes and regex patterns) apply to the string values of the
    ``path`` and ``file_path`` arguments. Custom checks may deny a call or
    return rewritten ("sanitized") arguments.
    """

    # Argument names whose string values are treated as file paths.
    _PATH_KEYS: tuple[str, ...] = ("path", "file_path")

    def __init__(self) -> None:
        self._blocked_paths: list[str] = []
        self._blocked_patterns: list[re.Pattern[str]] = []
        self._custom_checks: list[PreCheckFn] = []

    def block_path_prefix(self, prefix: str) -> None:
        """Block any file path starting with this prefix.

        Trailing ``/`` is stripped from *prefix* and the match is a plain string
        prefix, so blocking ``/etc`` also blocks ``/etcd``. Consequently
        ``"/"`` strips to ``""``, which matches every string path.
        """
        self._blocked_paths.append(prefix.rstrip("/"))

    def block_path_pattern(self, pattern: str) -> None:
        """Block paths matching this regex (``re.search``: a match anywhere in the path counts)."""
        self._blocked_patterns.append(re.compile(pattern))

    def add_check(self, check: PreCheckFn) -> None:
        """
        Add a custom pre-check. Check receives (tool_name, args); must not mutate caller's args.

        Returns (allowed, sanitized_args or error_message): (True, dict) or (True, None) or (False, str).
        Returned sanitized_args are used for subsequent checks and invocation; path rules are
        re-applied to them before the call is allowed.
        """
        self._custom_checks.append(check)

    @staticmethod
    def _path_variants(path: str) -> tuple[str, ...]:
        """Return *path* plus its lexically normalized form when that differs.

        Normalization collapses ``.``/``..`` segments and repeated slashes so a
        traversal such as ``/tmp/../etc/passwd`` cannot evade a ``/etc`` block.
        Both forms are checked, so normalization can only add denials.
        """
        normalized = posixpath.normpath(path)
        # POSIX normpath preserves exactly two leading slashes; collapse them so
        # "//etc/passwd" is also compared as "/etc/passwd".
        if normalized.startswith("//"):
            normalized = "/" + normalized.lstrip("/")
        return (path,) if normalized == path else (path, normalized)

    def _path_block_reason(self, args: dict[str, Any]) -> str | None:
        """Return the denial reason if any path-like argument is blocked, else None."""
        for key in self._PATH_KEYS:
            path = args.get(key)
            if not isinstance(path, str):
                continue
            variants = self._path_variants(path)
            for prefix in self._blocked_paths:
                # startswith(prefix) already covers "prefix/..." paths.
                if any(v.startswith(prefix) for v in variants):
                    return "Blocked path prefix: " + prefix
            for pat in self._blocked_patterns:
                if any(pat.search(v) for v in variants):
                    return "Blocked path pattern"
        return None

    @staticmethod
    def _deny(tool_name: str, reason: str) -> PreCheckResult:
        """Log a blocked call and build the corresponding denial result."""
        logger.info("Guardrails pre_check blocked", extra={"tool_name": tool_name, "reason": reason})
        return PreCheckResult(allowed=False, error_message=reason)

    def pre_check(self, tool_name: str, args: dict[str, Any]) -> PreCheckResult:
        """Run all pre-checks.

        Order: path rules on the incoming args, then each custom check in
        registration order, then path rules again on the final args if any
        check returned sanitized args (so a rewrite cannot route around a block).

        Args:
            tool_name: Name of the tool being invoked (passed to custom checks and logs).
            args: The caller's arguments; never mutated here (a shallow copy is used).

        Returns:
            PreCheckResult (allowed, sanitized_args, error_message). When allowed,
            ``sanitized_args`` holds the arguments to invoke the tool with.
        """
        working = dict(args)  # Copy to avoid mutating caller's args

        reason = self._path_block_reason(working)
        if reason is not None:
            return self._deny(tool_name, reason)

        sanitized = False
        for check in self._custom_checks:
            allowed, result = check(tool_name, working)
            if not allowed:
                return self._deny(tool_name, result if isinstance(result, str) else "Check failed")
            if isinstance(result, dict):
                # Copy so later checks/invocation don't alias a dict the check still holds.
                working = dict(result)
                sanitized = True

        if sanitized:
            reason = self._path_block_reason(working)
            if reason is not None:
                return self._deny(tool_name, reason)

        return PreCheckResult(allowed=True, sanitized_args=working)

    def post_check(self, tool_name: str, result: Any) -> tuple[bool, str]:
        """Optional post-check; return (True, "") or (False, error_message).

        Currently a no-op that accepts every result.
        """
        return True, ""