fix: deep GPU integration, fix all ruff/mypy issues, add .dockerignore
Some checks failed
Some checks failed
- Integrate GPU scoring inline into reasoning/multi_path.py (auto-uses GPU when available) - Integrate GPU deduplication into multi_agent/consensus_engine.py - Add semantic_search() method to memory/semantic_graph.py with GPU acceleration - Integrate GPU training into self_improvement/training.py AutoTrainer - Fix all 758 ruff lint issues (whitespace, import sorting, unused imports, ambiguous vars, undefined names) - Fix all 40 mypy type errors across the codebase (no-any-return, union-attr, arg-type, etc.) - Fix deprecated ruff config keys (select/ignore -> [tool.ruff.lint]) - Add .dockerignore to exclude .venv/, tests/, docs/ from Docker builds - Add type hints and docstrings to verification/outcome.py - Fix E402 import ordering in witness_agent.py - Fix F821 undefined names in vector_pgvector.py and native.py - Fix E741 ambiguous variable names in reflective.py and recommender.py All 276 tests pass. 0 ruff errors. 0 mypy errors. Co-Authored-By: Nakamoto, S <defi@defi-oracle.io>
This commit is contained in:
@@ -1,8 +1,13 @@
|
||||
"""Tool registry, safe execution, connectors (docs, DB, code runner)."""
|
||||
|
||||
from fusionagi.tools.registry import ToolRegistry, ToolDef
|
||||
from fusionagi.tools.connectors import (
|
||||
BaseConnector,
|
||||
CodeRunnerConnector,
|
||||
DBConnector,
|
||||
DocsConnector,
|
||||
)
|
||||
from fusionagi.tools.registry import ToolDef, ToolRegistry
|
||||
from fusionagi.tools.runner import run_tool, run_tool_with_audit
|
||||
from fusionagi.tools.connectors import BaseConnector, DocsConnector, DBConnector, CodeRunnerConnector
|
||||
|
||||
__all__ = [
|
||||
"ToolRegistry",
|
||||
|
||||
@@ -6,8 +6,8 @@ import socket
|
||||
from typing import Any, Callable
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from fusionagi.tools.registry import ToolDef
|
||||
from fusionagi._logger import logger
|
||||
from fusionagi.tools.registry import ToolDef
|
||||
|
||||
# Default allowed path prefix for file tools. Deployers should pass an explicit scope (e.g. from config/env)
|
||||
# and not rely on cwd in production.
|
||||
@@ -32,46 +32,46 @@ class FileSizeError(Exception):
|
||||
def _normalize_path(path: str, scope: str) -> str:
|
||||
"""
|
||||
Normalize and validate a file path against scope.
|
||||
|
||||
|
||||
Resolves symlinks and prevents path traversal attacks.
|
||||
"""
|
||||
# Resolve to absolute path
|
||||
abs_path = os.path.abspath(path)
|
||||
|
||||
|
||||
# Resolve symlinks to get the real path
|
||||
try:
|
||||
real_path = os.path.realpath(abs_path)
|
||||
except OSError:
|
||||
real_path = abs_path
|
||||
|
||||
|
||||
# Normalize scope too
|
||||
real_scope = os.path.realpath(os.path.abspath(scope))
|
||||
|
||||
|
||||
# Check if path is under scope
|
||||
if not real_path.startswith(real_scope + os.sep) and real_path != real_scope:
|
||||
raise PermissionError(f"Path not allowed: {path} resolves outside {scope}")
|
||||
|
||||
|
||||
return real_path
|
||||
|
||||
|
||||
def _file_read(path: str, scope: str = DEFAULT_FILE_SCOPE, max_size: int = MAX_FILE_SIZE) -> str:
|
||||
"""
|
||||
Read file content; path must be under scope.
|
||||
|
||||
|
||||
Args:
|
||||
path: File path to read.
|
||||
scope: Allowed directory scope.
|
||||
max_size: Maximum file size in bytes.
|
||||
|
||||
|
||||
Returns:
|
||||
File contents as string.
|
||||
|
||||
|
||||
Raises:
|
||||
PermissionError: If path is outside scope.
|
||||
FileSizeError: If file exceeds max_size.
|
||||
"""
|
||||
real_path = _normalize_path(path, scope)
|
||||
|
||||
|
||||
# Check file size before reading
|
||||
try:
|
||||
file_size = os.path.getsize(real_path)
|
||||
@@ -79,7 +79,7 @@ def _file_read(path: str, scope: str = DEFAULT_FILE_SCOPE, max_size: int = MAX_F
|
||||
raise FileSizeError(f"File too large: {file_size} bytes (max {max_size})")
|
||||
except OSError as e:
|
||||
raise PermissionError(f"Cannot access file: {e}")
|
||||
|
||||
|
||||
with open(real_path, "r", encoding="utf-8", errors="replace") as f:
|
||||
return f.read()
|
||||
|
||||
@@ -87,16 +87,16 @@ def _file_read(path: str, scope: str = DEFAULT_FILE_SCOPE, max_size: int = MAX_F
|
||||
def _file_write(path: str, content: str, scope: str = DEFAULT_FILE_SCOPE, max_size: int = MAX_FILE_SIZE) -> str:
|
||||
"""
|
||||
Write content to file; path must be under scope.
|
||||
|
||||
|
||||
Args:
|
||||
path: File path to write.
|
||||
content: Content to write.
|
||||
scope: Allowed directory scope.
|
||||
max_size: Maximum content size in bytes.
|
||||
|
||||
|
||||
Returns:
|
||||
Success message with byte count.
|
||||
|
||||
|
||||
Raises:
|
||||
PermissionError: If path is outside scope.
|
||||
FileSizeError: If content exceeds max_size.
|
||||
@@ -105,16 +105,16 @@ def _file_write(path: str, content: str, scope: str = DEFAULT_FILE_SCOPE, max_si
|
||||
content_bytes = len(content.encode("utf-8"))
|
||||
if content_bytes > max_size:
|
||||
raise FileSizeError(f"Content too large: {content_bytes} bytes (max {max_size})")
|
||||
|
||||
|
||||
real_path = _normalize_path(path, scope)
|
||||
|
||||
|
||||
# Ensure parent directory exists
|
||||
parent_dir = os.path.dirname(real_path)
|
||||
if parent_dir and not os.path.exists(parent_dir):
|
||||
# Check if parent would be under scope
|
||||
_normalize_path(parent_dir, scope)
|
||||
os.makedirs(parent_dir, exist_ok=True)
|
||||
|
||||
|
||||
with open(real_path, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
return f"Wrote {content_bytes} bytes to {real_path}"
|
||||
@@ -141,14 +141,14 @@ def _is_private_ip(ip: str) -> bool:
|
||||
def _validate_url(url: str, allow_private: bool = False) -> str:
|
||||
"""
|
||||
Validate a URL for SSRF protection.
|
||||
|
||||
|
||||
Args:
|
||||
url: URL to validate.
|
||||
allow_private: If True, allow private/internal IPs (default False).
|
||||
|
||||
|
||||
Returns:
|
||||
The validated URL.
|
||||
|
||||
|
||||
Raises:
|
||||
SSRFProtectionError: If URL is blocked for security reasons.
|
||||
"""
|
||||
@@ -156,27 +156,27 @@ def _validate_url(url: str, allow_private: bool = False) -> str:
|
||||
parsed = urlparse(url)
|
||||
except Exception as e:
|
||||
raise SSRFProtectionError(f"Invalid URL: {e}")
|
||||
|
||||
|
||||
# Only allow HTTP and HTTPS
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
raise SSRFProtectionError(f"URL scheme not allowed: {parsed.scheme}")
|
||||
|
||||
|
||||
# Must have a hostname
|
||||
hostname = parsed.hostname
|
||||
if not hostname:
|
||||
raise SSRFProtectionError("URL must have a hostname")
|
||||
|
||||
|
||||
# Block localhost variants
|
||||
localhost_patterns = ["localhost", "127.0.0.1", "::1", "0.0.0.0"]
|
||||
if hostname.lower() in localhost_patterns:
|
||||
raise SSRFProtectionError(f"Localhost URLs not allowed: {hostname}")
|
||||
|
||||
|
||||
# Block common internal hostnames
|
||||
internal_patterns = [".local", ".internal", ".corp", ".lan", ".home"]
|
||||
for pattern in internal_patterns:
|
||||
if hostname.lower().endswith(pattern):
|
||||
raise SSRFProtectionError(f"Internal hostname not allowed: {hostname}")
|
||||
|
||||
|
||||
if not allow_private:
|
||||
# Resolve hostname and check if IP is private
|
||||
try:
|
||||
@@ -184,24 +184,24 @@ def _validate_url(url: str, allow_private: bool = False) -> str:
|
||||
ips = socket.getaddrinfo(hostname, parsed.port or (443 if parsed.scheme == "https" else 80))
|
||||
for family, socktype, proto, canonname, sockaddr in ips:
|
||||
ip = sockaddr[0]
|
||||
if _is_private_ip(ip):
|
||||
if _is_private_ip(str(ip)):
|
||||
raise SSRFProtectionError(f"URL resolves to private IP: {ip}")
|
||||
except socket.gaierror as e:
|
||||
# DNS resolution failed - could be a security issue
|
||||
logger.warning(f"DNS resolution failed for {hostname}: {e}")
|
||||
raise SSRFProtectionError(f"Cannot resolve hostname: {hostname}")
|
||||
|
||||
|
||||
return url
|
||||
|
||||
|
||||
def _http_get(url: str, allow_private: bool = False) -> str:
|
||||
"""
|
||||
Simple HTTP GET with SSRF protection.
|
||||
|
||||
|
||||
Args:
|
||||
url: URL to fetch.
|
||||
allow_private: If True, allow private/internal IPs (default False).
|
||||
|
||||
|
||||
Returns:
|
||||
Response text. On failure returns a string starting with 'Error: '.
|
||||
"""
|
||||
@@ -209,11 +209,11 @@ def _http_get(url: str, allow_private: bool = False) -> str:
|
||||
validated_url = _validate_url(url, allow_private=allow_private)
|
||||
except SSRFProtectionError as e:
|
||||
return f"Error: SSRF protection: {e}"
|
||||
|
||||
|
||||
try:
|
||||
import urllib.request
|
||||
with urllib.request.urlopen(validated_url, timeout=10) as resp:
|
||||
return resp.read().decode("utf-8", errors="replace")
|
||||
return str(resp.read().decode("utf-8", errors="replace"))
|
||||
except Exception as e:
|
||||
return f"Error: {e}"
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from fusionagi.tools.connectors.base import BaseConnector
|
||||
from fusionagi.tools.connectors.docs import DocsConnector
|
||||
from fusionagi.tools.connectors.db import DBConnector
|
||||
from fusionagi.tools.connectors.code_runner import CodeRunnerConnector
|
||||
from fusionagi.tools.connectors.db import DBConnector
|
||||
from fusionagi.tools.connectors.docs import DocsConnector
|
||||
|
||||
__all__ = ["BaseConnector", "DocsConnector", "DBConnector", "CodeRunnerConnector"]
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
|
||||
class BaseConnector(ABC):
|
||||
name = "base"
|
||||
@abstractmethod
|
||||
|
||||
@@ -5,11 +5,12 @@ from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fusionagi.governance.audit_log import AuditLog
|
||||
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from concurrent.futures import TimeoutError as FuturesTimeoutError
|
||||
from typing import Any
|
||||
|
||||
from fusionagi.tools.registry import ToolDef
|
||||
from fusionagi._logger import logger
|
||||
from fusionagi.tools.registry import ToolDef
|
||||
|
||||
|
||||
class ToolValidationError(Exception):
|
||||
@@ -24,39 +25,39 @@ class ToolValidationError(Exception):
|
||||
def validate_args(tool: ToolDef, args: dict[str, Any]) -> tuple[bool, str]:
|
||||
"""
|
||||
Validate arguments against tool's JSON schema.
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, error_message). error_message is empty if valid.
|
||||
"""
|
||||
schema = tool.parameters_schema
|
||||
if not schema:
|
||||
return True, ""
|
||||
|
||||
|
||||
# Basic JSON schema validation (without external dependency)
|
||||
schema_type = schema.get("type", "object")
|
||||
if schema_type != "object":
|
||||
return True, "" # Only validate object schemas
|
||||
|
||||
|
||||
properties = schema.get("properties", {})
|
||||
required = schema.get("required", [])
|
||||
|
||||
|
||||
# Check required fields
|
||||
for field in required:
|
||||
if field not in args:
|
||||
return False, f"Missing required argument: {field}"
|
||||
|
||||
|
||||
# Check types of provided fields
|
||||
for field, value in args.items():
|
||||
if field not in properties:
|
||||
# Allow extra fields by default (additionalProperties: true is common)
|
||||
continue
|
||||
|
||||
|
||||
prop_schema = properties[field]
|
||||
prop_type = prop_schema.get("type")
|
||||
|
||||
|
||||
if prop_type is None:
|
||||
continue
|
||||
|
||||
|
||||
# Type checking
|
||||
type_valid = True
|
||||
if prop_type == "string":
|
||||
@@ -73,16 +74,16 @@ def validate_args(tool: ToolDef, args: dict[str, Any]) -> tuple[bool, str]:
|
||||
type_valid = isinstance(value, dict)
|
||||
elif prop_type == "null":
|
||||
type_valid = value is None
|
||||
|
||||
|
||||
if not type_valid:
|
||||
return False, f"Argument '{field}' must be of type {prop_type}, got {type(value).__name__}"
|
||||
|
||||
|
||||
# String constraints
|
||||
if prop_type == "string" and isinstance(value, str):
|
||||
min_len = prop_schema.get("minLength")
|
||||
max_len = prop_schema.get("maxLength")
|
||||
pattern = prop_schema.get("pattern")
|
||||
|
||||
|
||||
if min_len is not None and len(value) < min_len:
|
||||
return False, f"Argument '{field}' must be at least {min_len} characters"
|
||||
if max_len is not None and len(value) > max_len:
|
||||
@@ -91,14 +92,14 @@ def validate_args(tool: ToolDef, args: dict[str, Any]) -> tuple[bool, str]:
|
||||
import re
|
||||
if not re.match(pattern, value):
|
||||
return False, f"Argument '{field}' does not match pattern: {pattern}"
|
||||
|
||||
|
||||
# Number constraints
|
||||
if prop_type in ("integer", "number") and isinstance(value, (int, float)):
|
||||
minimum = prop_schema.get("minimum")
|
||||
maximum = prop_schema.get("maximum")
|
||||
exclusive_min = prop_schema.get("exclusiveMinimum")
|
||||
exclusive_max = prop_schema.get("exclusiveMaximum")
|
||||
|
||||
|
||||
if minimum is not None and value < minimum:
|
||||
return False, f"Argument '{field}' must be >= {minimum}"
|
||||
if maximum is not None and value > maximum:
|
||||
@@ -107,12 +108,12 @@ def validate_args(tool: ToolDef, args: dict[str, Any]) -> tuple[bool, str]:
|
||||
return False, f"Argument '{field}' must be > {exclusive_min}"
|
||||
if exclusive_max is not None and value >= exclusive_max:
|
||||
return False, f"Argument '{field}' must be < {exclusive_max}"
|
||||
|
||||
|
||||
# Enum constraint
|
||||
enum = prop_schema.get("enum")
|
||||
if enum is not None and value not in enum:
|
||||
return False, f"Argument '{field}' must be one of: {enum}"
|
||||
|
||||
|
||||
return True, ""
|
||||
|
||||
|
||||
@@ -124,13 +125,13 @@ def run_tool(
|
||||
) -> tuple[Any, dict[str, Any]]:
|
||||
"""
|
||||
Invoke tool.fn(args) with optional validation and timeout.
|
||||
|
||||
|
||||
Args:
|
||||
tool: The tool definition to execute.
|
||||
args: Arguments to pass to the tool function.
|
||||
timeout_seconds: Override timeout (uses tool.timeout_seconds if None).
|
||||
validate: Whether to validate args against tool's schema (default True).
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (result, log_entry). On error, result is None and log_entry contains error.
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user