Initial commit: add .gitignore and README
Some checks failed
Tests / test (3.10) (push) Has been cancelled
Tests / test (3.11) (push) Has been cancelled
Tests / test (3.12) (push) Has been cancelled
Tests / lint (push) Has been cancelled
Tests / docker (push) Has been cancelled

This commit is contained in:
defiQUG
2026-02-09 21:51:42 -08:00
commit c052b07662
3146 changed files with 808305 additions and 0 deletions

View File

@@ -0,0 +1,56 @@
"""Reasoning engine: chain-of-thought, tree-of-thought, and native symbolic reasoning."""
from fusionagi.reasoning.cot import (
build_cot_messages,
run_chain_of_thought,
)
from fusionagi.reasoning.tot import (
run_tree_of_thought,
run_tree_of_thought_detailed,
ThoughtBranch,
ThoughtNode,
ToTResult,
expand_node,
prune_subtree,
merge_subtrees,
)
from fusionagi.reasoning.native import (
NativeReasoningProvider,
analyze_prompt,
produce_head_output,
PromptAnalysis,
)
from fusionagi.reasoning.decomposition import decompose_recursive
from fusionagi.reasoning.multi_path import generate_and_score_parallel
from fusionagi.reasoning.recomposition import recompose, RecomposedResponse
from fusionagi.reasoning.meta_reasoning import (
challenge_assumptions,
detect_contradictions,
revisit_node,
)
__all__ = [
"build_cot_messages",
"run_chain_of_thought",
"run_tree_of_thought",
"run_tree_of_thought_detailed",
"ThoughtBranch",
"ThoughtNode",
"ToTResult",
"expand_node",
"prune_subtree",
"merge_subtrees",
"NativeReasoningProvider",
"analyze_prompt",
"produce_head_output",
"PromptAnalysis",
"decompose_recursive",
"load_context_for_reasoning",
"build_compact_prompt",
"generate_and_score_parallel",
"recompose",
"RecomposedResponse",
"challenge_assumptions",
"detect_contradictions",
"revisit_node",
]

View File

@@ -0,0 +1,71 @@
"""Retrieve-by-reference: load context for reasoning without token overflow."""
from __future__ import annotations
from typing import Any, Protocol, runtime_checkable
from fusionagi.schemas.atomic import AtomicSemanticUnit
from fusionagi.memory.sharding import Shard, shard_context
@runtime_checkable
class SemanticGraphLike(Protocol):
"""Protocol for semantic graph access."""
def get_unit(self, unit_id: str) -> AtomicSemanticUnit | None: ...
def query_units(self, unit_ids: list[str] | None = None, limit: int = 100) -> list[AtomicSemanticUnit]: ...
@runtime_checkable
class SharderLike(Protocol):
"""Protocol for sharding."""
def __call__(self, units: list[AtomicSemanticUnit], max_cluster_size: int) -> list[Shard]: ...
def load_context_for_reasoning(
query_units: list[AtomicSemanticUnit],
semantic_graph: SemanticGraphLike | None = None,
sharder: SharderLike | None = None,
max_cluster_size: int = 20,
) -> dict[str, Any]:
"""
Fetch relevant shards/units by reference for reasoning.
Returns structured context (unit IDs + summaries) rather than raw text.
"""
shard_fn = sharder or shard_context
shards = shard_fn(query_units, max_cluster_size=max_cluster_size)
unit_refs: list[str] = []
unit_summaries: dict[str, str] = {}
for u in query_units:
unit_refs.append(u.unit_id)
unit_summaries[u.unit_id] = u.content[:150]
if semantic_graph:
for s in shards:
for uid in s.unit_ids:
if uid not in unit_summaries:
found = semantic_graph.get_unit(uid)
if found:
unit_summaries[uid] = found.content[:150]
unit_refs.append(uid)
return {"unit_refs": unit_refs, "shards": shards, "unit_summaries": unit_summaries}
def build_compact_prompt(
units: list[AtomicSemanticUnit],
max_chars: int = 4000,
) -> str:
"""Materialize text for units that fit; rest stay as references."""
parts: list[str] = []
total = 0
refs: list[str] = []
for u in units:
line = f"[{u.unit_id}] {u.content}\n"
if total + len(line) <= max_chars:
parts.append(line)
total += len(line)
else:
refs.append(u.unit_id)
if refs:
parts.append(f"\n[References: {', '.join(refs[:20])}]")
return "".join(parts) if parts else "[No units]"

View File

@@ -0,0 +1,43 @@
"""Chain-of-thought: prompt structure and trace storage."""
from typing import Any
from fusionagi.adapters.base import LLMAdapter
COT_SYSTEM = """You reason step by step. For each step, state your thought clearly.
Output your final conclusion or recommendation after your reasoning."""
def build_cot_messages(
query: str,
context: str | None = None,
trace_so_far: list[str] | None = None,
) -> list[dict[str, str]]:
"""Build message list for chain-of-thought: system + optional context + user query."""
messages: list[dict[str, str]] = [{"role": "system", "content": COT_SYSTEM}]
if context:
messages.append({"role": "user", "content": f"Context:\n{context}\n\nQuery: {query}"})
else:
messages.append({"role": "user", "content": query})
if trace_so_far:
assistant_content = "\n".join(trace_so_far)
messages.append({"role": "assistant", "content": assistant_content})
messages.append({"role": "user", "content": "Continue."})
return messages
def run_chain_of_thought(
adapter: LLMAdapter,
query: str,
context: str | None = None,
trace_so_far: list[str] | None = None,
**kwargs: Any,
) -> tuple[str, list[str]]:
"""
Run one CoT step; return (full_response, trace_entries).
Trace entries can be stored and passed as trace_so_far for multi-step CoT.
"""
messages = build_cot_messages(query, context=context, trace_so_far=trace_so_far)
response = adapter.complete(messages, **kwargs)
trace = (trace_so_far or []) + [response]
return response, trace

View File

@@ -0,0 +1,207 @@
"""Recursive semantic decomposition: split text into atomic units."""
from __future__ import annotations
import re
import uuid
from typing import Any
from fusionagi.reasoning.native import analyze_prompt
from fusionagi.schemas.atomic import (
AtomicSemanticUnit,
AtomicUnitType,
DecompositionResult,
RelationType,
SemanticRelation,
)
from fusionagi._logger import logger
def _make_unit_id(prefix: str = "asu") -> str:
"""Generate unique unit ID."""
return f"{prefix}_{uuid.uuid4().hex[:12]}"
def _is_atomic(text: str, min_words: int = 3) -> bool:
"""Check if text is irreducible (atomic)."""
content = " ".join(text.split()).strip()
if not content or len(content) < 10:
return True
words = len(content.split())
return words <= min_words
def _extract_questions(text: str) -> list[str]:
"""Extract explicit questions from text."""
questions: list[str] = []
content = " ".join(text.split()).strip()
q_parts = re.split(r"\?+", content)
for part in q_parts[:-1]:
q = part.strip()
if len(q) > 10:
questions.append(q + "?")
if not questions and any(w in content.lower() for w in ["how", "what", "why", "when", "where", "who"]):
questions.append(content)
return questions[:5]
def _extract_constraints(text: str) -> list[str]:
"""Extract constraint signals from text."""
constraints: list[str] = []
patterns = [
r"must\s+(\w[\w\s]+?)(?:\.|$)",
r"should\s+(\w[\w\s]+?)(?:\.|$)",
r"cannot\s+(\w[\w\s]+?)(?:\.|$)",
r"require[sd]?\s+(\w[\w\s]+?)(?:\.|$)",
r"constraint[s]?:\s*(\w[\w\s]+?)(?:\.|$)",
r"assume[sd]?\s+(\w[\w\s]+?)(?:\.|$)",
]
for pat in patterns:
for m in re.finditer(pat, text, re.I):
constraints.append(m.group(1).strip())
return list(dict.fromkeys(constraints))[:10]
def _extract_entities(text: str) -> list[str]:
"""Extract entity-like phrases."""
entities = re.findall(r'"([^"]+)"', text)
entities += re.findall(r"\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b", text)
return list(dict.fromkeys(e for e in entities if len(e) > 2))[:10]
def decompose_recursive(
text: str,
max_depth: int = 3,
parent_id: str | None = None,
current_depth: int = 0,
source_ref: str | None = None,
) -> DecompositionResult:
"""
Recursively decompose text into atomic semantic units.
Extracts entities, constraints, intents, assumptions, questions; recurses
on non-atomic segments. Integrates with native analyze_prompt for intent
and domain signals.
Args:
text: Input text to decompose.
max_depth: Maximum recursion depth.
parent_id: Parent unit ID for decomposition tree.
current_depth: Current recursion depth.
source_ref: Optional source reference.
Returns:
DecompositionResult with units and relations.
"""
content = " ".join(text.split()).strip()
if not content:
return DecompositionResult(units=[], relations=[], depth=current_depth)
units: list[AtomicSemanticUnit] = []
relations: list[SemanticRelation] = []
analysis = analyze_prompt(content)
# Root unit for this segment (if not already atomic)
root_id = _make_unit_id()
root_unit = AtomicSemanticUnit(
unit_id=root_id,
content=content[:500] + ("..." if len(content) > 500 else ""),
type=AtomicUnitType.INTENT if analysis.intent == "question" else AtomicUnitType.FACT,
confidence=0.8,
parent_id=parent_id,
source_ref=source_ref,
metadata={"intent": analysis.intent},
)
units.append(root_unit)
if parent_id:
relations.append(
SemanticRelation(from_id=parent_id, to_id=root_id, relation_type=RelationType.LOGICAL)
)
# Extract questions as atomic units
for q in _extract_questions(content):
q_id = _make_unit_id()
units.append(
AtomicSemanticUnit(
unit_id=q_id,
content=q,
type=AtomicUnitType.QUESTION,
confidence=0.9,
parent_id=root_id,
source_ref=source_ref,
)
)
relations.append(
SemanticRelation(from_id=root_id, to_id=q_id, relation_type=RelationType.LOGICAL)
)
# Extract constraints as atomic units
for c in _extract_constraints(content):
c_id = _make_unit_id()
units.append(
AtomicSemanticUnit(
unit_id=c_id,
content=c,
type=AtomicUnitType.CONSTRAINT,
confidence=0.85,
parent_id=root_id,
source_ref=source_ref,
)
)
relations.append(
SemanticRelation(from_id=root_id, to_id=c_id, relation_type=RelationType.LOGICAL)
)
# Extract entities as atomic units
for e in _extract_entities(content):
e_id = _make_unit_id()
units.append(
AtomicSemanticUnit(
unit_id=e_id,
content=e,
type=AtomicUnitType.FACT,
confidence=0.9,
parent_id=root_id,
source_ref=source_ref,
)
)
relations.append(
SemanticRelation(from_id=root_id, to_id=e_id, relation_type=RelationType.LOGICAL)
)
# If not atomic and depth allows, split and recurse
if not _is_atomic(content, min_words=8) and current_depth < max_depth:
sentences = re.split(r"[.!?]\s+", content)
if len(sentences) > 1:
for sent in sentences:
sent = sent.strip()
if len(sent) > 20:
sub = decompose_recursive(
sent,
max_depth=max_depth,
parent_id=root_id,
current_depth=current_depth + 1,
source_ref=source_ref,
)
units.extend(sub.units)
relations.extend(sub.relations)
# Dedupe by unit_id
seen: set[str] = set()
unique_units: list[AtomicSemanticUnit] = []
for u in units:
if u.unit_id not in seen:
seen.add(u.unit_id)
unique_units.append(u)
logger.debug(
"Decomposition complete",
extra={"depth": current_depth, "units": len(unique_units), "relations": len(relations)},
)
return DecompositionResult(
units=unique_units,
relations=relations,
depth=current_depth,
)

View File

@@ -0,0 +1,85 @@
"""Meta-reasoning: challenge assumptions, detect contradictions, revisit nodes."""
from __future__ import annotations
from typing import Any
from fusionagi.schemas.atomic import AtomicSemanticUnit, AtomicUnitType
from fusionagi.reasoning.tot import ThoughtNode, expand_node
from fusionagi._logger import logger
def challenge_assumptions(
units: list[AtomicSemanticUnit],
current_conclusion: str,
) -> list[str]:
"""
Identify and flag assumptions in units that support the conclusion.
"""
flagged: list[str] = []
conclusion_lower = current_conclusion.lower()
for u in units:
if u.type == AtomicUnitType.ASSUMPTION:
flagged.append(u.content)
elif "assume" in u.content.lower() or "assumption" in u.content.lower():
flagged.append(u.content)
elif u.type == AtomicUnitType.CONSTRAINT:
if any(w in conclusion_lower for w in ["must", "should", "require"]):
flagged.append(f"Constraint may be assumed: {u.content[:100]}")
logger.debug("Assumptions flagged", extra={"count": len(flagged)})
return flagged
def detect_contradictions(
units: list[AtomicSemanticUnit],
) -> list[tuple[str, str]]:
"""
Find conflicting units (heuristic: negation mismatch, same subject).
"""
neg_words = {"not", "no", "never", "none", "cannot", "shouldn't", "won't", "don't", "doesn't"}
pairs: list[tuple[str, str]] = []
for i, a in enumerate(units):
wa = set(a.content.lower().split())
for b in units[i + 1:]:
wb = set(b.content.lower().split())
a_neg = bool(wa & neg_words)
b_neg = bool(wb & neg_words)
if a_neg != b_neg:
overlap = len(wa & wb) / max(len(wa), 1)
if overlap > 0.2:
pairs.append((a.unit_id, b.unit_id))
logger.debug("Contradictions detected", extra={"count": len(pairs)})
return pairs
def revisit_node(
tree: ThoughtNode | None,
node_id: str,
new_evidence: str,
) -> ThoughtNode | None:
"""
Re-expand a node when new evidence arrives.
Creates a new child with the new evidence.
"""
if tree is None:
return None
def find_node(n: ThoughtNode) -> ThoughtNode | None:
if n.node_id == node_id:
return n
for c in n.children:
found = find_node(c)
if found:
return found
return None
node = find_node(tree)
if not node:
return tree
child = expand_node(node, new_evidence)
child.metadata["revisit_evidence"] = new_evidence[:200]
return tree

View File

@@ -0,0 +1,50 @@
"""Multi-path inference: parallel hypothesis generation and scoring."""
from __future__ import annotations
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Callable
from fusionagi.schemas.atomic import AtomicSemanticUnit
from fusionagi.reasoning.tot import ThoughtNode
from fusionagi._logger import logger
def _score_coherence(node: ThoughtNode, _units: list[AtomicSemanticUnit]) -> float:
return node.score * (0.9 + 0.1 * min(1, len(node.trace) / 5))
def _score_consistency(node: ThoughtNode, units: list[AtomicSemanticUnit]) -> float:
if not units:
return 0.5
unit_content = " ".join(u.content.lower() for u in units)
thought_words = set(node.thought.lower().split())
unit_words = set(unit_content.split())
overlap = len(thought_words & unit_words) / max(len(thought_words), 1)
return min(1.0, overlap * 2)
def generate_and_score_parallel(
hypotheses: list[str],
units: list[AtomicSemanticUnit],
score_fn: Callable[[ThoughtNode, list[AtomicSemanticUnit]], float] | None = None,
) -> list[tuple[ThoughtNode, float]]:
"""Score multiple hypotheses in parallel."""
score_fn = score_fn or (lambda n, u: _score_coherence(n, u) * 0.5 + _score_consistency(n, u) * 0.5)
def score_one(h: str, i: int) -> tuple[ThoughtNode, float]:
node = ThoughtNode(thought=h, trace=[h], unit_refs=[u.unit_id for u in units[:10]])
s = score_fn(node, units)
node.score = s
return node, s
results: list[tuple[ThoughtNode, float]] = []
with ThreadPoolExecutor(max_workers=min(len(hypotheses), 8)) as ex:
futures = {ex.submit(score_one, h, i): i for i, h in enumerate(hypotheses)}
for future in as_completed(futures):
try:
node, score = future.result()
results.append((node, score))
except Exception as e:
logger.warning("Multi-path score failed", extra={"error": str(e)})
return results

View File

@@ -0,0 +1,319 @@
"""
Native reasoning engine: symbolic, rule-based analysis independent of external LLMs.
Produces structured HeadOutput from prompt analysis using:
- Keyword and pattern extraction
- Head-specific domain logic
- Persona-driven synthesis
- No external API calls
"""
from __future__ import annotations
import re
from dataclasses import dataclass, field
from typing import Any
from fusionagi.config.head_personas import get_persona
from fusionagi.schemas.grounding import Citation
from fusionagi.schemas.head import HeadClaim, HeadId, HeadOutput, HeadRisk
@dataclass
class PromptAnalysis:
"""Structured analysis of a user prompt from native reasoning."""
intent: str = ""
entities: list[str] = field(default_factory=list)
constraints: list[str] = field(default_factory=list)
questions: list[str] = field(default_factory=list)
domain_signals: dict[str, float] = field(default_factory=dict)
keywords: set[str] = field(default_factory=set)
# Domain keywords per head: presence boosts relevance and shapes claims
HEAD_DOMAIN_KEYWORDS: dict[HeadId, set[str]] = {
HeadId.LOGIC: {"logic", "contradiction", "correct", "valid", "proof", "assumption", "therefore", "implies"},
HeadId.RESEARCH: {"source", "cite", "reference", "evidence", "study", "paper", "find", "search"},
HeadId.SYSTEMS: {"architecture", "scalability", "dependency", "system", "design", "component", "service"},
HeadId.STRATEGY: {"strategy", "roadmap", "priority", "tradeoff", "plan", "goal", "long-term"},
HeadId.PRODUCT: {"user", "ux", "product", "flow", "design", "experience", "interface"},
HeadId.SECURITY: {"security", "auth", "threat", "vulnerability", "secret", "encrypt", "attack"},
HeadId.SAFETY: {"safety", "harm", "policy", "ethical", "comply", "risk", "prevent"},
HeadId.RELIABILITY: {"reliability", "slo", "failover", "observability", "test", "load", "uptime"},
HeadId.COST: {"cost", "budget", "performance", "cache", "token", "efficient", "expensive"},
HeadId.DATA: {"data", "schema", "privacy", "retention", "memory", "storage", "database"},
HeadId.DEVEX: {"dev", "ci", "cd", "test", "tooling", "local", "developer", "workflow"},
}
def _extract_content(text: str) -> str:
"""Normalize and extract analyzable content from prompt."""
if not text:
return ""
# Collapse whitespace, strip
return " ".join(text.split()).strip()
def analyze_prompt(prompt: str) -> PromptAnalysis:
"""
Analyze prompt using pattern matching and keyword extraction.
No external APIs; pure symbolic reasoning.
"""
content = _extract_content(prompt).lower()
words = set(re.findall(r"\b[a-z0-9]{2,}\b", content))
analysis = PromptAnalysis(keywords=words)
# Intent: question vs statement vs request
if "?" in prompt:
analysis.intent = "question"
# Extract explicit questions
q_parts = re.split(r"\?+", prompt)
for part in q_parts[:-1]:
q = part.strip()
if len(q) > 10:
analysis.questions.append(q + "?")
elif any(w in content for w in ["how", "what", "why", "when", "where", "who"]):
analysis.intent = "question"
else:
analysis.intent = "statement" if len(prompt.split()) > 5 else "request"
# Entity-like phrases (title case or quoted)
entities = re.findall(r'"([^"]+)"', prompt)
entities += re.findall(r"\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b", prompt)
analysis.entities = list(dict.fromkeys(e for e in entities if len(e) > 2))[:10]
# Constraint signals
constraint_patterns = [
r"must\s+(\w[\w\s]+?)(?:\.|$)",
r"should\s+(\w[\w\s]+?)(?:\.|$)",
r"cannot\s+(\w[\w\s]+?)(?:\.|$)",
r"require[sd]?\s+(\w[\w\s]+?)(?:\.|$)",
r"constraint[s]?:\s*(\w[\w\s]+?)(?:\.|$)",
]
for pat in constraint_patterns:
for m in re.finditer(pat, prompt, re.I):
analysis.constraints.append(m.group(1).strip())
# Domain relevance per head
for hid, keywords in HEAD_DOMAIN_KEYWORDS.items():
if hid == HeadId.WITNESS:
continue
overlap = len(words & keywords) / max(len(keywords), 1)
analysis.domain_signals[hid.value] = min(1.0, overlap * 3)
return analysis
def _derive_claims_for_head(
head_id: HeadId,
analysis: PromptAnalysis,
prompt: str,
semantic_facts: list[dict[str, Any]] | None = None,
) -> list[HeadClaim]:
"""Derive atomic claims from analysis based on head domain."""
claims: list[HeadClaim] = []
persona = get_persona(head_id)
relevance = analysis.domain_signals.get(head_id.value, 0.3)
# Base claim from prompt summary
summary_claim = f"The prompt addresses: {analysis.intent}"
if analysis.entities:
summary_claim += f" involving {', '.join(analysis.entities[:3])}"
claims.append(
HeadClaim(
claim_text=summary_claim,
confidence=0.7 + relevance * 0.2,
evidence=[Citation(source_id="prompt_analysis", excerpt=prompt[:200], confidence=1.0)],
assumptions=[],
)
)
# Domain-specific claims
if head_id == HeadId.LOGIC:
claims.append(
HeadClaim(
claim_text="Logical consistency should be verified for any derived conclusions.",
confidence=0.8,
evidence=[],
assumptions=["Formal reasoning applies"],
)
)
elif head_id == HeadId.SECURITY:
if analysis.domain_signals.get(HeadId.SECURITY.value, 0) > 0.2:
claims.append(
HeadClaim(
claim_text="Security implications should be explicitly evaluated.",
confidence=0.85,
evidence=[],
assumptions=[],
)
)
elif head_id == HeadId.SAFETY:
claims.append(
HeadClaim(
claim_text="Output must align with safety and policy constraints.",
confidence=0.9,
evidence=[],
assumptions=[],
)
)
elif head_id == HeadId.STRATEGY and analysis.constraints:
claims.append(
HeadClaim(
claim_text=f"Constraints identified: {'; '.join(analysis.constraints[:2])}.",
confidence=0.75,
evidence=[],
assumptions=[],
)
)
# Memory-augmented: add claims from semantic facts when available
if semantic_facts:
for fact in semantic_facts[:3]:
stmt = fact.get("statement", "")
if stmt:
claims.append(
HeadClaim(
claim_text=stmt[:200],
confidence=0.6,
evidence=[
Citation(
source_id=fact.get("source", "semantic_memory"),
excerpt=stmt[:100],
confidence=0.8,
)
],
assumptions=["Stored fact; verify currency"],
)
)
return claims
def _derive_risks_for_head(head_id: HeadId, analysis: PromptAnalysis) -> list[HeadRisk]:
"""Identify risks based on head domain and analysis."""
risks: list[HeadRisk] = []
relevance = analysis.domain_signals.get(head_id.value, 0.3)
if relevance < 0.2:
risks.append(
HeadRisk(
description="Low domain relevance; analysis may be shallow for this head.",
severity="low",
)
)
if head_id == HeadId.SECURITY and relevance > 0.3:
risks.append(
HeadRisk(
description="Security-sensitive topic; require explicit threat assessment.",
severity="high",
)
)
if head_id == HeadId.SAFETY:
risks.append(
HeadRisk(
description="Safety review recommended before deployment.",
severity="medium",
)
)
return risks
def _synthesize_summary(head_id: HeadId, analysis: PromptAnalysis, claims: list[HeadClaim]) -> str:
"""Synthesize persona-appropriate summary from claims and analysis."""
persona = get_persona(head_id)
tone = persona.get("tone", "balanced")
expression = persona.get("expression", "neutral")
head_name = head_id.value.replace("_", " ").title()
parts: list[str] = []
if claims:
primary = claims[0].claim_text
parts.append(f"From {expression} perspective: {primary[:120]}{'...' if len(primary) > 120 else ''}.")
if analysis.questions:
parts.append(f"Addresses {len(analysis.questions)} explicit question(s).")
if analysis.constraints:
parts.append(f"Constraints noted: {len(analysis.constraints)}.")
if not parts:
parts.append(f"{head_name} head analysis: prompt analyzed with {tone} assessment.")
return " ".join(parts)
def produce_head_output(
head_id: HeadId,
prompt: str,
semantic_facts: list[dict[str, Any]] | None = None,
) -> HeadOutput:
"""
Produce structured HeadOutput using native reasoning only.
No external LLM calls. Uses symbolic analysis, domain logic, and persona-driven synthesis.
"""
if head_id == HeadId.WITNESS:
raise ValueError("Witness does not produce HeadOutput; use WitnessAgent")
analysis = analyze_prompt(prompt)
claims = _derive_claims_for_head(head_id, analysis, prompt, semantic_facts)
risks = _derive_risks_for_head(head_id, analysis)
summary = _synthesize_summary(head_id, analysis, claims)
# Recommended actions from analysis
actions: list[str] = []
if analysis.questions:
actions.append("Address each explicit question in the response.")
if analysis.constraints:
actions.append("Verify output satisfies stated constraints.")
if head_id in (HeadId.SECURITY, HeadId.SAFETY):
actions.append("Perform domain-specific review before finalizing.")
return HeadOutput(
head_id=head_id,
summary=summary,
claims=claims,
risks=risks,
questions=analysis.questions[:3] if analysis.questions else [],
recommended_actions=actions[:5] or ["Proceed with synthesis."],
tone_guidance=get_persona(head_id).get("tone", "balanced"),
)
def _domain_for_head(head_id: HeadId) -> str:
"""Map head to semantic memory domain."""
return head_id.value
class NativeReasoningProvider:
"""
Provider for native reasoning: produces HeadOutput without external APIs.
Optional memory integration: when semantic_memory is provided, retrieves
relevant facts to ground claims. When episodic_memory is provided, can
reference similar past outcomes (future: full retrieval).
"""
def __init__(
self,
semantic_memory: "SemanticMemory | None" = None,
episodic_memory: "EpisodicMemory | None" = None,
) -> None:
self._semantic = semantic_memory
self._episodic = episodic_memory
def produce_head_output(self, head_id: HeadId, prompt: str) -> HeadOutput:
"""Produce HeadOutput for the given head and prompt."""
return produce_head_output(
head_id,
prompt,
semantic_facts=self._get_relevant_facts(head_id) if self._semantic else None,
)
def _get_relevant_facts(self, head_id: HeadId, limit: int = 5) -> list[dict[str, Any]]:
"""Retrieve domain-relevant facts from semantic memory."""
if not self._semantic:
return []
domain = _domain_for_head(head_id)
return self._semantic.query(domain=domain, limit=limit)

View File

@@ -0,0 +1,49 @@
"""Dynamic recomposition: build higher-order insights with traceability."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
from fusionagi.schemas.atomic import AtomicSemanticUnit
from fusionagi.reasoning.tot import ThoughtNode
@dataclass
class RecomposedResponse:
"""Recomposed response with traceability to atomic units."""
summary: str = ""
key_claims: list[str] = field(default_factory=list)
unit_refs: list[str] = field(default_factory=list)
confidence: float = 0.0
metadata: dict[str, Any] = field(default_factory=dict)
def recompose(
thought_nodes: list[ThoughtNode],
atomic_units: list[AtomicSemanticUnit],
) -> RecomposedResponse:
"""Build higher-order insights from selected thought nodes."""
unit_refs: set[str] = set()
key_claims: list[str] = []
summaries: list[str] = []
for node in thought_nodes:
if node.thought:
summaries.append(node.thought[:200])
key_claims.append(node.thought[:150])
for uid in node.unit_refs:
unit_refs.add(uid)
summary = " ".join(summaries[:3]) if summaries else "No insights."
if len(summaries) > 3:
summary += " [truncated]"
avg_score = sum(n.score for n in thought_nodes) / len(thought_nodes) if thought_nodes else 0.0
return RecomposedResponse(
summary=summary,
key_claims=key_claims[:10],
unit_refs=list(unit_refs),
confidence=min(1.0, avg_score),
metadata={"node_count": len(thought_nodes), "unit_count": len(unit_refs)},
)

402
fusionagi/reasoning/tot.py Normal file
View File

@@ -0,0 +1,402 @@
"""Tree-of-thought: multi-branch reasoning with evaluation and selection.
Tree-of-Thought (ToT) extends Chain-of-Thought by exploring multiple reasoning paths
and selecting the best one. This is useful for complex problems where the first
reasoning path may not be optimal.
Key concepts:
- Branch: A single reasoning path
- ThoughtNode: Hierarchical node with depth, unit_refs, children (Super Big Brain)
- Evaluation: Scoring each branch for quality
- Selection: Choosing the best branch based on evaluation
- Pruning: Discarding low-quality branches early
"""
import json
import uuid
from dataclasses import dataclass, field
from typing import Any
from fusionagi.adapters.base import LLMAdapter
from fusionagi.reasoning.cot import run_chain_of_thought, build_cot_messages
from fusionagi._logger import logger
@dataclass
class ThoughtNode:
"""Hierarchical reasoning node: supports arbitrary depth, subtree independence, unit_refs."""
node_id: str = field(default_factory=lambda: f"node_{uuid.uuid4().hex[:12]}")
parent_id: str | None = None
thought: str = ""
trace: list[str] = field(default_factory=list)
score: float = 0.0
children: list["ThoughtNode"] = field(default_factory=list)
depth: int = 0
unit_refs: list[str] = field(default_factory=list)
metadata: dict[str, Any] = field(default_factory=dict)
def expand_node(
node: ThoughtNode,
new_thought: str,
unit_refs: list[str] | None = None,
) -> ThoughtNode:
"""Create a child node under node."""
child = ThoughtNode(
parent_id=node.node_id,
thought=new_thought,
trace=node.trace + [new_thought],
depth=node.depth + 1,
unit_refs=unit_refs or list(node.unit_refs),
)
node.children.append(child)
return child
def prune_subtree(node: ThoughtNode, prune_threshold: float) -> ThoughtNode:
"""Remove children below prune_threshold; return node."""
node.children = [c for c in node.children if c.score >= prune_threshold]
for c in node.children:
prune_subtree(c, prune_threshold)
return node
def merge_subtrees(nodes: list[ThoughtNode], threshold: float = 0.8) -> ThoughtNode | None:
"""Merge sibling nodes when they converge on same conclusion (similarity > threshold)."""
if not nodes:
return None
if len(nodes) == 1:
return nodes[0]
best = max(nodes, key=lambda n: n.score)
for n in nodes:
if n is best:
continue
if n.score >= best.score * threshold:
best.thought += "\n[Alternative] " + n.thought[:200]
return best
@dataclass
class ThoughtBranch:
"""A single reasoning branch in the tree."""
branch_id: int
thought: str
trace: list[str]
score: float = 0.0
is_terminal: bool = False
children: list["ThoughtBranch"] = field(default_factory=list)
metadata: dict[str, Any] = field(default_factory=dict)
@dataclass
class ToTResult:
"""Result of Tree-of-Thought reasoning."""
best_response: str
best_trace: list[str]
best_score: float
all_branches: list[ThoughtBranch]
total_llm_calls: int
selection_reason: str
# System prompts for ToT
TOT_GENERATION_SYSTEM = """You are exploring different approaches to solve a problem.
Generate a distinct reasoning approach. Be creative and consider different angles.
State your thought process clearly step by step."""
TOT_EVALUATION_SYSTEM = """You are evaluating the quality of reasoning approaches.
Score each approach from 0 to 1 based on:
- Logical soundness (is the reasoning valid?)
- Completeness (does it address all aspects?)
- Clarity (is it easy to follow?)
- Practicality (can it be implemented?)
Output ONLY a JSON object: {"score": 0.X, "reason": "brief explanation"}"""
def _generate_branch(
adapter: LLMAdapter,
query: str,
context: str | None,
branch_num: int,
previous_branches: list[ThoughtBranch],
**kwargs: Any,
) -> ThoughtBranch:
"""Generate a single reasoning branch."""
# Build prompt that encourages diverse thinking
diversity_hint = ""
if previous_branches:
prev_summaries = [
f"Approach {b.branch_id}: {b.thought[:100]}..."
for b in previous_branches
]
diversity_hint = f"\n\nPrevious approaches tried:\n" + "\n".join(prev_summaries)
diversity_hint += "\n\nGenerate a DIFFERENT approach."
messages = [
{"role": "system", "content": TOT_GENERATION_SYSTEM},
{
"role": "user",
"content": f"Query: {query}{diversity_hint}" + (f"\n\nContext: {context}" if context else ""),
},
]
response = adapter.complete(messages, **kwargs)
return ThoughtBranch(
branch_id=branch_num,
thought=response,
trace=[response],
)
def _evaluate_branch(
adapter: LLMAdapter,
branch: ThoughtBranch,
query: str,
**kwargs: Any,
) -> float:
"""Evaluate a reasoning branch and return a score."""
messages = [
{"role": "system", "content": TOT_EVALUATION_SYSTEM},
{
"role": "user",
"content": f"Query: {query}\n\nReasoning approach:\n{branch.thought}\n\nScore this approach.",
},
]
response = adapter.complete(messages, **kwargs)
# Parse score from response
try:
# Try to extract JSON
if "{" in response:
json_start = response.index("{")
json_end = response.rindex("}") + 1
json_str = response[json_start:json_end]
result = json.loads(json_str)
score = float(result.get("score", 0.5))
branch.metadata["evaluation_reason"] = result.get("reason", "")
return max(0.0, min(1.0, score)) # Clamp to [0, 1]
except (json.JSONDecodeError, ValueError, KeyError):
pass
# Fallback: try to extract a number
import re
numbers = re.findall(r"0?\.\d+|1\.0|[01]", response)
if numbers:
try:
return max(0.0, min(1.0, float(numbers[0])))
except ValueError:
pass
return 0.5 # Default score if parsing fails
def _select_best_branch(branches: list[ThoughtBranch]) -> tuple[ThoughtBranch, str]:
"""Select the best branch based on scores."""
if not branches:
raise ValueError("No branches to select from")
if len(branches) == 1:
return branches[0], "Only one branch available"
# Sort by score descending
sorted_branches = sorted(branches, key=lambda b: b.score, reverse=True)
best = sorted_branches[0]
# Check if there's a clear winner
if len(sorted_branches) > 1:
score_diff = best.score - sorted_branches[1].score
if score_diff > 0.2:
reason = f"Clear winner with score {best.score:.2f} (next best: {sorted_branches[1].score:.2f})"
else:
reason = f"Selected highest score {best.score:.2f} among close alternatives"
else:
reason = f"Single branch with score {best.score:.2f}"
return best, reason
def run_tree_of_thought(
adapter: LLMAdapter,
query: str,
context: str | None = None,
max_branches: int = 3,
depth: int = 1,
prune_threshold: float = 0.3,
**kwargs: Any,
) -> tuple[str, list[str]]:
"""
Run Tree-of-Thought reasoning with multiple branches.
Args:
adapter: LLM adapter for generation and evaluation.
query: The question or problem to reason about.
context: Optional context to include.
max_branches: Maximum number of reasoning branches to explore.
depth: Number of refinement iterations (1 = single pass, 2+ = iterative refinement).
prune_threshold: Minimum score to keep a branch (branches below are pruned).
**kwargs: Additional arguments passed to adapter.complete().
Returns:
Tuple of (best_response, trace_list).
"""
if max_branches < 1:
max_branches = 1
if max_branches == 1:
# Fall back to simple CoT for single branch
return run_chain_of_thought(adapter, query, context=context, **kwargs)
logger.info(
"Starting Tree-of-Thought",
extra={"query_length": len(query), "max_branches": max_branches, "depth": depth},
)
total_llm_calls = 0
branches: list[ThoughtBranch] = []
# Generate initial branches
for i in range(max_branches):
branch = _generate_branch(adapter, query, context, i, branches, **kwargs)
total_llm_calls += 1
branches.append(branch)
# Evaluate all branches
for branch in branches:
branch.score = _evaluate_branch(adapter, branch, query, **kwargs)
total_llm_calls += 1
# Prune low-quality branches
branches = [b for b in branches if b.score >= prune_threshold]
if not branches:
# All branches pruned - fall back to CoT
logger.warning("All ToT branches pruned, falling back to CoT")
return run_chain_of_thought(adapter, query, context=context, **kwargs)
# Iterative refinement for depth > 1
for d in range(1, depth):
refined_branches = []
for branch in branches:
# Generate a refined version
refinement_prompt = f"""Original approach:
{branch.thought}
Score: {branch.score:.2f}
Feedback: {branch.metadata.get('evaluation_reason', 'N/A')}
Improve this approach based on the feedback. Make it more complete and rigorous."""
messages = [
{"role": "system", "content": TOT_GENERATION_SYSTEM},
{"role": "user", "content": f"Query: {query}\n\n{refinement_prompt}"},
]
refined_thought = adapter.complete(messages, **kwargs)
total_llm_calls += 1
refined_branch = ThoughtBranch(
branch_id=branch.branch_id,
thought=refined_thought,
trace=branch.trace + [f"[Refinement {d}] {refined_thought}"],
)
refined_branch.score = _evaluate_branch(adapter, refined_branch, query, **kwargs)
total_llm_calls += 1
# Keep the better version
if refined_branch.score > branch.score:
refined_branches.append(refined_branch)
else:
refined_branches.append(branch)
branches = refined_branches
# Select the best branch
best_branch, selection_reason = _select_best_branch(branches)
logger.info(
"Tree-of-Thought completed",
extra={
"best_score": best_branch.score,
"total_branches": len(branches),
"total_llm_calls": total_llm_calls,
},
)
# Build comprehensive trace
trace = [
f"[ToT Branch {best_branch.branch_id}] Score: {best_branch.score:.2f}",
best_branch.thought,
]
if best_branch.metadata.get("evaluation_reason"):
trace.append(f"[Evaluation] {best_branch.metadata['evaluation_reason']}")
trace.append(f"[Selection] {selection_reason}")
return best_branch.thought, trace
def run_tree_of_thought_detailed(
adapter: LLMAdapter,
query: str,
context: str | None = None,
max_branches: int = 3,
depth: int = 1,
prune_threshold: float = 0.3,
**kwargs: Any,
) -> ToTResult:
"""
Run Tree-of-Thought and return detailed results including all branches.
Same as run_tree_of_thought but returns a ToTResult with full information.
"""
if max_branches < 1:
max_branches = 1
if max_branches == 1:
response, trace = run_chain_of_thought(adapter, query, context=context, **kwargs)
single_branch = ThoughtBranch(branch_id=0, thought=response, trace=trace, score=0.5)
return ToTResult(
best_response=response,
best_trace=trace,
best_score=0.5,
all_branches=[single_branch],
total_llm_calls=1,
selection_reason="Single branch (CoT mode)",
)
total_llm_calls = 0
branches: list[ThoughtBranch] = []
# Generate and evaluate branches
for i in range(max_branches):
branch = _generate_branch(adapter, query, context, i, branches, **kwargs)
total_llm_calls += 1
branch.score = _evaluate_branch(adapter, branch, query, **kwargs)
total_llm_calls += 1
branches.append(branch)
all_branches = list(branches) # Keep all for result
# Prune
branches = [b for b in branches if b.score >= prune_threshold]
if not branches:
# Use best of all branches even if below threshold
branches = sorted(all_branches, key=lambda b: b.score, reverse=True)[:1]
# Select best
best_branch, selection_reason = _select_best_branch(branches)
return ToTResult(
best_response=best_branch.thought,
best_trace=best_branch.trace,
best_score=best_branch.score,
all_branches=all_branches,
total_llm_calls=total_llm_calls,
selection_reason=selection_reason,
)