"""Built-in tools: file read/write (scoped), HTTP GET (with SSRF protection), query state."""

import ipaddress
import os
import socket
import urllib.request
from typing import Any, Callable
from urllib.parse import urlparse

from fusionagi._logger import logger
from fusionagi.tools.registry import ToolDef

# Default allowed path prefix for file tools. Deployers should pass an explicit scope (e.g. from config/env)
# and not rely on cwd in production. Note: evaluated once, at import time.
DEFAULT_FILE_SCOPE = os.path.abspath(os.getcwd())

# Maximum file size for read/write operations (10MB)
MAX_FILE_SIZE = 10 * 1024 * 1024

# Hostnames that always refer to the local machine; rejected by _validate_url
# regardless of allow_private.
_LOCALHOST_NAMES = frozenset({"localhost", "127.0.0.1", "::1", "0.0.0.0"})

# DNS suffixes commonly used for internal networks; rejected by _validate_url
# regardless of allow_private.
_INTERNAL_SUFFIXES = (".local", ".internal", ".corp", ".lan", ".home")


class SSRFProtectionError(Exception):
    """Raised when a URL is blocked for SSRF protection."""


class FileSizeError(Exception):
    """Raised when file size exceeds limit."""


def _normalize_path(path: str, scope: str) -> str:
    """
    Normalize and validate a file path against scope.

    Both ``path`` and ``scope`` are made absolute and symlink-resolved, so
    ``..`` segments and symlinks that point outside the scope are rejected.

    Args:
        path: Candidate file path (relative paths resolve against the cwd).
        scope: Directory the resolved path must equal or live under.

    Returns:
        The resolved real path.

    Raises:
        PermissionError: If the resolved path is outside ``scope``.
    """
    abs_path = os.path.abspath(path)
    try:
        real_path = os.path.realpath(abs_path)
    except OSError:
        real_path = abs_path

    real_scope = os.path.realpath(os.path.abspath(scope))

    # commonpath compares whole path components, so "/data2" is not treated as
    # inside "/data", and a root scope ("/") works. A plain "scope + os.sep"
    # prefix test yields "//" for a root scope and rejects everything. normcase
    # makes the comparison case-insensitive on Windows and is a no-op on POSIX.
    norm_path = os.path.normcase(real_path)
    norm_scope = os.path.normcase(real_scope)
    try:
        inside = os.path.commonpath([norm_path, norm_scope]) == norm_scope
    except ValueError:
        # Raised e.g. for paths on different Windows drives.
        inside = False

    if not inside:
        raise PermissionError(f"Path not allowed: {path} resolves outside {scope}")

    return real_path


def _file_read(path: str, scope: str = DEFAULT_FILE_SCOPE, max_size: int = MAX_FILE_SIZE) -> str:
    """
    Read file content; path must be under scope.

    Args:
        path: File path to read.
        scope: Allowed directory scope.
        max_size: Maximum file size in bytes.

    Returns:
        File contents as string (UTF-8, undecodable bytes replaced).

    Raises:
        PermissionError: If path is outside scope or the file cannot be stat'ed.
        FileSizeError: If file exceeds max_size.
    """
    real_path = _normalize_path(path, scope)

    # Fast rejection for regular files whose size is known up front.
    try:
        file_size = os.path.getsize(real_path)
    except OSError as e:
        raise PermissionError(f"Cannot access file: {e}") from e
    if file_size > max_size:
        raise FileSizeError(f"File too large: {file_size} bytes (max {max_size})")

    # getsize() reports 0 for special/virtual files, so also bound the read
    # itself. Each decoded character comes from at least one byte, so more
    # than max_size characters implies more than max_size bytes.
    with open(real_path, "r", encoding="utf-8", errors="replace") as f:
        content = f.read(max_size + 1)
    if len(content) > max_size:
        raise FileSizeError(f"File too large: more than {max_size} bytes (max {max_size})")
    return content


def _file_write(path: str, content: str, scope: str = DEFAULT_FILE_SCOPE, max_size: int = MAX_FILE_SIZE) -> str:
    """
    Write content to file; path must be under scope.

    Missing parent directories are created (only inside scope).

    Args:
        path: File path to write.
        content: Content to write.
        scope: Allowed directory scope.
        max_size: Maximum content size in bytes.

    Returns:
        Success message with byte count.

    Raises:
        PermissionError: If path is outside scope.
        FileSizeError: If content exceeds max_size.
    """
    # Check content size (as UTF-8) before touching the filesystem.
    content_bytes = len(content.encode("utf-8"))
    if content_bytes > max_size:
        raise FileSizeError(f"Content too large: {content_bytes} bytes (max {max_size})")

    real_path = _normalize_path(path, scope)

    parent_dir = os.path.dirname(real_path)
    if parent_dir and not os.path.exists(parent_dir):
        # Defensive: the parent must itself be within scope before we create it.
        _normalize_path(parent_dir, scope)
        os.makedirs(parent_dir, exist_ok=True)

    with open(real_path, "w", encoding="utf-8") as f:
        f.write(content)

    return f"Wrote {content_bytes} bytes to {real_path}"


def _is_private_ip(ip: str) -> bool:
    """
    Return True if ``ip`` is unsafe for a server-side fetch.

    Unsafe means private, loopback, link-local, multicast, reserved,
    unspecified, or otherwise not globally routable (e.g. the RFC 6598
    shared range 100.64.0.0/10). Any IPv4-mapped IPv6 address is also
    treated as unsafe. Strings that do not parse as an IP are treated as
    unsafe.
    """
    try:
        addr = ipaddress.ip_address(ip)
    except ValueError:
        return True  # Invalid IP is treated as unsafe

    return (
        addr.is_private
        or addr.is_loopback
        or addr.is_link_local
        or addr.is_multicast
        or addr.is_reserved
        or addr.is_unspecified
        # Catches non-routable space that is_private misses (e.g. 100.64.0.0/10).
        or not addr.is_global
        # Block IPv6 mapped IPv4 addresses
        or (isinstance(addr, ipaddress.IPv6Address) and addr.ipv4_mapped is not None)
    )


def _validate_url(url: str, allow_private: bool = False) -> str:
    """
    Validate a URL for SSRF protection.

    Checks, in order: parseable URL and port, http/https scheme, non-empty
    hostname, not a localhost name, not an internal DNS suffix, and (unless
    ``allow_private``) that every resolved address is public. The localhost
    and internal-suffix checks apply even when ``allow_private`` is True.

    Args:
        url: URL to validate.
        allow_private: If True, skip DNS resolution and the private-IP check (default False).

    Returns:
        The validated URL (unchanged).

    Raises:
        SSRFProtectionError: If URL is blocked for security reasons.
    """
    try:
        parsed = urlparse(url)
        # .port is parsed lazily and raises ValueError for a non-numeric or
        # out-of-range port, so it must be read inside this try.
        port = parsed.port
    except Exception as e:  # malformed or non-string input of any kind is an invalid URL
        raise SSRFProtectionError(f"Invalid URL: {e}") from e

    # Only allow HTTP and HTTPS
    if parsed.scheme not in ("http", "https"):
        raise SSRFProtectionError(f"URL scheme not allowed: {parsed.scheme}")

    hostname = parsed.hostname
    if not hostname:
        raise SSRFProtectionError("URL must have a hostname")

    # A trailing dot ("localhost.", "db.internal.") names the same host; strip
    # it so the name-based checks below cannot be bypassed.
    host_key = hostname.lower().rstrip(".")

    if host_key in _LOCALHOST_NAMES:
        raise SSRFProtectionError(f"Localhost URLs not allowed: {hostname}")

    if host_key.endswith(_INTERNAL_SUFFIXES):
        raise SSRFProtectionError(f"Internal hostname not allowed: {hostname}")

    if not allow_private:
        default_port = 443 if parsed.scheme == "https" else 80
        try:
            addr_infos = socket.getaddrinfo(hostname, port or default_port)
        except (socket.gaierror, UnicodeError) as e:
            # UnicodeError comes from IDNA encoding of malformed hostnames.
            logger.warning(f"DNS resolution failed for {hostname}: {e}")
            raise SSRFProtectionError(f"Cannot resolve hostname: {hostname}") from e

        # Every resolved address must be public; one private answer blocks the URL.
        for _family, _socktype, _proto, _canonname, sockaddr in addr_infos:
            ip = sockaddr[0]
            if _is_private_ip(str(ip)):
                raise SSRFProtectionError(f"URL resolves to private IP: {ip}")

    return url


class _SSRFSafeRedirectHandler(urllib.request.HTTPRedirectHandler):
    """
    Redirect handler that re-runs SSRF validation on every redirect target.

    urllib follows 3xx redirects automatically. Without this handler, a URL
    that passes _validate_url could redirect the request to a loopback,
    private, or metadata address.
    """

    def __init__(self, allow_private: bool = False) -> None:
        super().__init__()
        self._allow_private = allow_private

    def redirect_request(self, req, fp, code, msg, headers, newurl):  # type: ignore[no-untyped-def]
        try:
            _validate_url(newurl, allow_private=self._allow_private)
        except SSRFProtectionError:
            # The base handler normally drains and closes fp after a
            # successful redirect_request; do it ourselves before aborting.
            fp.close()
            raise
        return super().redirect_request(req, fp, code, msg, headers, newurl)


def _http_get(url: str, allow_private: bool = False) -> str:
    """
    Simple HTTP GET with SSRF protection.

    The initial URL and every redirect target are validated with
    _validate_url. NOTE(review): urllib resolves the hostname again when
    connecting, so a DNS answer that changes between validation and
    connection (DNS rebinding) is not caught here.

    Args:
        url: URL to fetch.
        allow_private: If True, allow private/internal IPs (default False).

    Returns:
        Response text. On failure returns a string starting with 'Error: '.
    """
    try:
        validated_url = _validate_url(url, allow_private=allow_private)
    except SSRFProtectionError as e:
        return f"Error: SSRF protection: {e}"

    opener = urllib.request.build_opener(_SSRFSafeRedirectHandler(allow_private))
    try:
        with opener.open(validated_url, timeout=10) as resp:
            return resp.read().decode("utf-8", errors="replace")
    except SSRFProtectionError as e:
        # Raised by the redirect handler when a redirect target is blocked.
        return f"Error: SSRF protection: {e}"
    except Exception as e:  # tool boundary: report every failure as text
        return f"Error: {e}"


def make_file_read_tool(scope: str = DEFAULT_FILE_SCOPE) -> ToolDef:
    """File read tool with path scope (reads are confined to ``scope``)."""

    def fn(path: str) -> str:
        return _file_read(path, scope=scope)

    return ToolDef(
        name="file_read",
        description="Read file content; path must be under allowed scope",
        fn=fn,
        parameters_schema={
            "type": "object",
            "properties": {"path": {"type": "string", "description": "File path"}},
            "required": ["path"],
        },
        permission_scope=["file"],
        timeout_seconds=5.0,
    )


def make_file_write_tool(scope: str = DEFAULT_FILE_SCOPE) -> ToolDef:
    """File write tool with path scope (writes are confined to ``scope``)."""

    def fn(path: str, content: str) -> str:
        return _file_write(path, content, scope=scope)

    return ToolDef(
        name="file_write",
        description="Write content to file; path must be under allowed scope",
        fn=fn,
        parameters_schema={
            "type": "object",
            "properties": {
                "path": {"type": "string", "description": "File path"},
                "content": {"type": "string", "description": "Content to write"},
            },
            "required": ["path", "content"],
        },
        permission_scope=["file"],
        timeout_seconds=5.0,
    )


def make_http_get_tool() -> ToolDef:
    """HTTP GET tool; private/internal targets are always blocked (allow_private is not exposed)."""
    return ToolDef(
        name="http_get",
        description="Perform HTTP GET request and return response body",
        fn=_http_get,
        parameters_schema={
            "type": "object",
            "properties": {"url": {"type": "string", "description": "URL to fetch"}},
            "required": ["url"],
        },
        permission_scope=["network"],
        timeout_seconds=15.0,
    )


def make_query_state_tool(get_state_fn: Callable[[str], Any]) -> ToolDef:
    """Internal tool: query task state (injected get_state_fn(task_id) -> state/trace)."""

    def fn(task_id: str) -> Any:
        return get_state_fn(task_id)

    return ToolDef(
        name="query_state",
        description="Query task state and trace (internal)",
        fn=fn,
        parameters_schema={
            "type": "object",
            "properties": {"task_id": {"type": "string"}},
            "required": ["task_id"],
        },
        permission_scope=["internal"],
        timeout_seconds=2.0,
    )