"""Built-in tools: file read, file write, HTTP GET, query state.

In advisory mode (default), scope violations and SSRF detections are logged as
advisories but the operation proceeds. The system learns from outcomes rather
than being prevented from exploring.
"""

import ipaddress
import os
import socket
import urllib.request
from typing import Any, Callable
from urllib.parse import urlparse

from fusionagi._logger import logger
from fusionagi.tools.registry import ToolDef

# Default allowed path prefix for file tools, captured once at import time.
# Deployers should pass an explicit scope (e.g. from config/env) and not rely on
# cwd in production.
DEFAULT_FILE_SCOPE = os.path.abspath(os.getcwd())

# Default file size limit in bytes (None = unlimited). This value is bound as the
# default ``max_size`` argument when the functions below are defined, so
# reassigning this name afterwards does not change those defaults; pass
# ``max_size`` explicitly to apply a different limit.
MAX_FILE_SIZE: int | None = None


class SSRFProtectionError(Exception):
    """Raised when a URL is blocked for SSRF protection."""


class FileSizeError(Exception):
    """Raised when file size exceeds limit."""


def _is_within_scope(real_path: str, real_scope: str) -> bool:
    """Return True if resolved ``real_path`` equals ``real_scope`` or lies beneath it."""
    # normcase makes the comparison case-insensitive where the OS is (Windows);
    # it is the identity on POSIX.
    path_cmp = os.path.normcase(real_path)
    scope_cmp = os.path.normcase(real_scope)
    try:
        # commonpath compares whole path components, so "/a/bc" is not treated
        # as inside "/a/b", and a root scope ("/") contains everything.
        return os.path.commonpath([path_cmp, scope_cmp]) == scope_cmp
    except ValueError:
        # Raised for paths on different drives (Windows) or a mix of absolute
        # and relative paths; neither can be inside the scope.
        return False


def _normalize_path(path: str, scope: str, advisory: bool = True) -> str:
    """
    Resolve a file path (absolute, symlinks followed) and check it against a scope.

    In advisory mode (default), out-of-scope paths are logged but allowed through.
    The system learns from outcomes.

    Args:
        path: Path to resolve; relative paths are resolved against the cwd.
        scope: Directory the resolved path is expected to be in (or equal to).
        advisory: If True, an out-of-scope path is only logged.

    Returns:
        The resolved path.

    Raises:
        PermissionError: If ``advisory`` is False and the path resolves outside ``scope``.
    """
    abs_path = os.path.abspath(path)
    try:
        real_path = os.path.realpath(abs_path)
    except OSError:
        real_path = abs_path
    real_scope = os.path.realpath(os.path.abspath(scope))

    if _is_within_scope(real_path, real_scope):
        return real_path

    if not advisory:
        raise PermissionError(f"Path not allowed: {path} resolves outside {scope}")
    logger.info(
        "File scope advisory: path outside scope (proceeding)",
        extra={"path": path, "scope": scope, "mode": "advisory"},
    )
    return real_path


def _file_read(
    path: str,
    scope: str = DEFAULT_FILE_SCOPE,
    max_size: int | None = MAX_FILE_SIZE,
    advisory: bool = True,
) -> str:
    """
    Read file content. Scope and size checks are advisory by default.

    Args:
        path: File path to read.
        scope: Allowed directory scope.
        max_size: Maximum file size in bytes (``None`` = unlimited).
        advisory: If True, violations are logged but allowed.

    Returns:
        File contents as string, decoded as UTF-8 with undecodable bytes replaced.

    Raises:
        PermissionError: Path outside ``scope`` (non-advisory), or the file size
            could not be determined while ``max_size`` is set.
        FileSizeError: File larger than ``max_size`` (non-advisory).
        OSError: Propagated from opening or reading the file.
    """
    real_path = _normalize_path(path, scope, advisory=advisory)
    if max_size is not None:
        try:
            file_size = os.path.getsize(real_path)
        except OSError as e:
            raise PermissionError(f"Cannot access file: {e}") from e
        if file_size > max_size:
            if not advisory:
                raise FileSizeError(f"File too large: {file_size} bytes (max {max_size})")
            logger.info(
                "File size advisory: file exceeds limit (proceeding)",
                extra={"path": path, "size": file_size, "limit": max_size, "mode": "advisory"},
            )
    with open(real_path, "r", encoding="utf-8", errors="replace") as f:
        return f.read()


def _file_write(
    path: str,
    content: str,
    scope: str = DEFAULT_FILE_SCOPE,
    max_size: int | None = MAX_FILE_SIZE,
    advisory: bool = True,
) -> str:
    """
    Write content to file. Scope and size checks are advisory by default.

    Missing parent directories are created.

    Args:
        path: File path to write.
        content: Content to write (stored UTF-8 encoded).
        scope: Allowed directory scope.
        max_size: Maximum content size in bytes (``None`` = unlimited).
        advisory: If True, violations are logged but allowed.

    Returns:
        Success message with byte count.

    Raises:
        FileSizeError: Encoded content larger than ``max_size`` (non-advisory).
        PermissionError: Path outside ``scope`` (non-advisory).
        OSError: Propagated from creating directories or writing the file.
    """
    encoded = content.encode("utf-8")
    byte_count = len(encoded)
    if max_size is not None and byte_count > max_size:
        if not advisory:
            raise FileSizeError(f"Content too large: {byte_count} bytes (max {max_size})")
        logger.info(
            "File size advisory: content exceeds limit (proceeding)",
            extra={"path": path, "size": byte_count, "limit": max_size, "mode": "advisory"},
        )

    real_path = _normalize_path(path, scope, advisory=advisory)
    parent_dir = os.path.dirname(real_path)
    if parent_dir and not os.path.exists(parent_dir):
        _normalize_path(parent_dir, scope, advisory=advisory)
        os.makedirs(parent_dir, exist_ok=True)

    # Write the already-encoded bytes so the file holds exactly ``content`` and the
    # reported byte count matches what is on disk (text mode would translate
    # newlines on platforms such as Windows).
    with open(real_path, "wb") as f:
        f.write(encoded)
    return f"Wrote {byte_count} bytes to {real_path}"


def _is_private_ip(ip: str) -> bool:
    """
    Return True if an IP address is private, loopback, or otherwise unsafe.

    Besides the explicit categories below, any address that ``ipaddress`` does
    not consider globally reachable (e.g. the 100.64.0.0/10 shared address space
    from RFC 6598) and every IPv4-mapped IPv6 address is treated as unsafe.
    Unparseable input is treated as unsafe.
    """
    try:
        addr = ipaddress.ip_address(ip)
    except ValueError:
        return True  # Invalid IP is treated as unsafe
    return (
        addr.is_private
        or addr.is_loopback
        or addr.is_link_local
        or addr.is_multicast
        or addr.is_reserved
        or addr.is_unspecified
        or not addr.is_global
        # Block IPv6 mapped IPv4 addresses wholesale rather than relying on the
        # checks above being applied to the embedded IPv4 address.
        or (isinstance(addr, ipaddress.IPv6Address) and addr.ipv4_mapped is not None)
    )


# Hostnames always treated as the local machine.
_LOCALHOST_NAMES = frozenset({"localhost", "127.0.0.1", "::1", "0.0.0.0"})

# Hostname suffixes conventionally used for internal networks.
_INTERNAL_SUFFIXES = (".local", ".internal", ".corp", ".lan", ".home")


def _is_loopback_literal(host: str) -> bool:
    """Return True if ``host`` is an IP literal for loopback or the unspecified address."""
    try:
        addr = ipaddress.ip_address(host)
    except ValueError:
        return False
    return addr.is_loopback or addr.is_unspecified


def _validate_url(url: str, allow_private: bool = True, advisory: bool = True) -> str:
    """
    Validate a URL. In advisory mode (default), issues are logged but the URL is allowed through.

    Checks, in order: parseability, http/https scheme, presence of a hostname,
    localhost names and loopback IP literals, internal hostname suffixes, and
    (only when ``allow_private`` is False) the DNS-resolved addresses.

    Args:
        url: URL to validate.
        allow_private: If True (default), allow private/internal IPs (skips the
            DNS-based check; localhost and internal-name checks still apply).
        advisory: If True, log issues as advisories instead of raising.

    Returns:
        The validated URL (unchanged).

    Raises:
        SSRFProtectionError: Only when ``advisory`` is False and a check fails.
    """
    try:
        parsed = urlparse(url)
    except Exception as e:  # e.g. ValueError for a malformed IPv6 netloc
        if advisory:
            logger.info("URL advisory: parse error (proceeding)", extra={"url": url[:100], "error": str(e)})
            return url
        raise SSRFProtectionError(f"Invalid URL: {e}") from e

    if parsed.scheme not in ("http", "https"):
        if advisory:
            logger.info("URL advisory: non-HTTP scheme (proceeding)", extra={"scheme": parsed.scheme})
            return url
        raise SSRFProtectionError(f"URL scheme not allowed: {parsed.scheme}")

    hostname = parsed.hostname
    if not hostname:
        if advisory:
            logger.info("URL advisory: no hostname (proceeding)", extra={"url": url[:100]})
            return url
        raise SSRFProtectionError("URL must have a hostname")

    # Canonical form for pattern matching: case-insensitive, and a trailing dot
    # ("localhost.", "db.internal.") names the same host as without it.
    host = hostname.lower().rstrip(".")

    if host in _LOCALHOST_NAMES or _is_loopback_literal(host):
        if advisory:
            logger.info("URL advisory: localhost detected (proceeding)", extra={"hostname": hostname})
            return url
        raise SSRFProtectionError(f"Localhost URLs not allowed: {hostname}")

    if host.endswith(_INTERNAL_SUFFIXES):
        if advisory:
            logger.info("URL advisory: internal hostname (proceeding)", extra={"hostname": hostname})
            return url
        raise SSRFProtectionError(f"Internal hostname not allowed: {hostname}")

    if not allow_private:
        # ``parsed.port`` raises ValueError for a non-numeric or out-of-range port.
        try:
            port = parsed.port
        except ValueError as e:
            if advisory:
                logger.info("URL advisory: invalid port (proceeding)", extra={"url": url[:100], "error": str(e)})
                return url
            raise SSRFProtectionError(f"Invalid port in URL: {e}") from e

        default_port = 443 if parsed.scheme == "https" else 80
        try:
            addr_infos = socket.getaddrinfo(hostname, port or default_port)
        except (socket.gaierror, UnicodeError) as e:
            # UnicodeError: the hostname could not be IDNA-encoded for lookup.
            logger.warning(f"DNS resolution failed for {hostname}: {e}")
            if not advisory:
                raise SSRFProtectionError(f"Cannot resolve hostname: {hostname}") from e
            return url

        for _family, _socktype, _proto, _canonname, sockaddr in addr_infos:
            ip = sockaddr[0]
            if _is_private_ip(str(ip)):
                if advisory:
                    logger.info("URL advisory: private IP (proceeding)", extra={"ip": ip})
                    return url
                raise SSRFProtectionError(f"URL resolves to private IP: {ip}")

    return url


def _http_get(url: str, allow_private: bool = True) -> str:
    """
    HTTP GET with advisory URL validation.

    Validation always runs in advisory mode here, so flagged URLs are logged
    and still fetched.

    Args:
        url: URL to fetch.
        allow_private: If True (default), allow private/internal IPs (skips the
            DNS-based private-IP advisory).

    Returns:
        Response text, decoded with the charset declared by the response (UTF-8
        when none is declared or the name is unknown; undecodable bytes are
        replaced). On failure returns a string starting with 'Error: '.
    """
    try:
        validated_url = _validate_url(url, allow_private=allow_private, advisory=True)
    except SSRFProtectionError as e:
        return f"Error: SSRF protection: {e}"

    # NOTE(review): urlopen follows redirects without re-validating the target,
    # and also serves non-HTTP schemes (e.g. file://) that advisory validation
    # lets through. Consistent with advisory mode; revisit before exposing this
    # tool to untrusted callers.
    try:
        with urllib.request.urlopen(validated_url, timeout=10) as resp:
            body = resp.read()
            charset = resp.headers.get_content_charset() or "utf-8"
    except Exception as e:  # tool boundary: report every failure as text
        return f"Error: {e}"

    try:
        return body.decode(charset, errors="replace")
    except LookupError:  # unknown charset name in the Content-Type header
        return body.decode("utf-8", errors="replace")


def make_file_read_tool(scope: str = DEFAULT_FILE_SCOPE) -> ToolDef:
    """File read tool with path scope (scope checks run in advisory mode)."""

    def fn(path: str) -> str:
        return _file_read(path, scope=scope)

    return ToolDef(
        name="file_read",
        description="Read file content; path must be under allowed scope",
        fn=fn,
        parameters_schema={
            "type": "object",
            "properties": {"path": {"type": "string", "description": "File path"}},
            "required": ["path"],
        },
        permission_scope=["file"],
        timeout_seconds=5.0,
    )


def make_file_write_tool(scope: str = DEFAULT_FILE_SCOPE) -> ToolDef:
    """File write tool with path scope (scope checks run in advisory mode)."""

    def fn(path: str, content: str) -> str:
        return _file_write(path, content, scope=scope)

    return ToolDef(
        name="file_write",
        description="Write content to file; path must be under allowed scope",
        fn=fn,
        parameters_schema={
            "type": "object",
            "properties": {
                "path": {"type": "string", "description": "File path"},
                "content": {"type": "string", "description": "Content to write"},
            },
            "required": ["path", "content"],
        },
        permission_scope=["file"],
        timeout_seconds=5.0,
    )


def make_http_get_tool() -> ToolDef:
    """HTTP GET tool backed by ``_http_get`` (advisory URL validation)."""
    return ToolDef(
        name="http_get",
        description="Perform HTTP GET request and return response body",
        fn=_http_get,
        parameters_schema={
            "type": "object",
            "properties": {"url": {"type": "string", "description": "URL to fetch"}},
            "required": ["url"],
        },
        permission_scope=["network"],
        timeout_seconds=15.0,
    )


def make_query_state_tool(get_state_fn: Callable[[str], Any]) -> ToolDef:
    """Internal tool: query task state (injected get_state_fn(task_id) -> state/trace)."""

    def fn(task_id: str) -> Any:
        return get_state_fn(task_id)

    return ToolDef(
        name="query_state",
        description="Query task state and trace (internal)",
        fn=fn,
        parameters_schema={
            "type": "object",
            "properties": {"task_id": {"type": "string"}},
            "required": ["task_id"],
        },
        permission_scope=["internal"],
        timeout_seconds=2.0,
    )