Initial commit: add .gitignore and README
This commit is contained in:
56
fusionagi/reasoning/__init__.py
Normal file
56
fusionagi/reasoning/__init__.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""Reasoning engine: chain-of-thought, tree-of-thought, and native symbolic reasoning."""
|
||||
|
||||
from fusionagi.reasoning.cot import (
|
||||
build_cot_messages,
|
||||
run_chain_of_thought,
|
||||
)
|
||||
from fusionagi.reasoning.tot import (
|
||||
run_tree_of_thought,
|
||||
run_tree_of_thought_detailed,
|
||||
ThoughtBranch,
|
||||
ThoughtNode,
|
||||
ToTResult,
|
||||
expand_node,
|
||||
prune_subtree,
|
||||
merge_subtrees,
|
||||
)
|
||||
from fusionagi.reasoning.native import (
|
||||
NativeReasoningProvider,
|
||||
analyze_prompt,
|
||||
produce_head_output,
|
||||
PromptAnalysis,
|
||||
)
|
||||
from fusionagi.reasoning.decomposition import decompose_recursive
|
||||
from fusionagi.reasoning.multi_path import generate_and_score_parallel
|
||||
from fusionagi.reasoning.recomposition import recompose, RecomposedResponse
|
||||
from fusionagi.reasoning.meta_reasoning import (
|
||||
challenge_assumptions,
|
||||
detect_contradictions,
|
||||
revisit_node,
|
||||
)
|
||||
|
||||
# The context-loader helpers are advertised in __all__ below, so they must be
# bound in this namespace; without this import `from fusionagi.reasoning import *`
# fails with AttributeError on the first missing name.
from fusionagi.reasoning.context_loader import (
    build_compact_prompt,
    load_context_for_reasoning,
)

# Public API of the reasoning package, grouped by originating submodule.
__all__ = [
    # cot
    "build_cot_messages",
    "run_chain_of_thought",
    # tot
    "run_tree_of_thought",
    "run_tree_of_thought_detailed",
    "ThoughtBranch",
    "ThoughtNode",
    "ToTResult",
    "expand_node",
    "prune_subtree",
    "merge_subtrees",
    # native
    "NativeReasoningProvider",
    "analyze_prompt",
    "produce_head_output",
    "PromptAnalysis",
    # decomposition
    "decompose_recursive",
    # context_loader
    "load_context_for_reasoning",
    "build_compact_prompt",
    # multi_path
    "generate_and_score_parallel",
    # recomposition
    "recompose",
    "RecomposedResponse",
    # meta_reasoning
    "challenge_assumptions",
    "detect_contradictions",
    "revisit_node",
]
|
||||
71
fusionagi/reasoning/context_loader.py
Normal file
71
fusionagi/reasoning/context_loader.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""Retrieve-by-reference: load context for reasoning without token overflow."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
from fusionagi.schemas.atomic import AtomicSemanticUnit
|
||||
from fusionagi.memory.sharding import Shard, shard_context
|
||||
|
||||
|
||||
@runtime_checkable
class SemanticGraphLike(Protocol):
    """Structural interface for the semantic-graph lookups used when loading context."""

    def get_unit(self, unit_id: str) -> AtomicSemanticUnit | None:
        """Look up a single unit by ID; the annotation permits ``None`` as a result."""
        ...

    def query_units(
        self,
        unit_ids: list[str] | None = None,
        limit: int = 100,
    ) -> list[AtomicSemanticUnit]:
        """Return a list of units, optionally narrowed by ``unit_ids`` and ``limit``."""
        ...
|
||||
|
||||
|
||||
@runtime_checkable
class SharderLike(Protocol):
    """Structural interface for a callable that groups units into shards."""

    def __call__(
        self,
        units: list[AtomicSemanticUnit],
        max_cluster_size: int,
    ) -> list[Shard]:
        """Partition ``units`` into shards, given a ``max_cluster_size`` setting."""
        ...
|
||||
|
||||
|
||||
def load_context_for_reasoning(
    query_units: list[AtomicSemanticUnit],
    semantic_graph: SemanticGraphLike | None = None,
    sharder: SharderLike | None = None,
    max_cluster_size: int = 20,
) -> dict[str, Any]:
    """
    Assemble a by-reference reasoning context for ``query_units``.

    Args:
        query_units: Units the reasoning step starts from.
        semantic_graph: Optional graph used to resolve shard members that are
            not already among ``query_units``.
        sharder: Optional sharding callable; ``shard_context`` is used when
            this is falsy.
        max_cluster_size: Forwarded to the sharder as a keyword argument.

    Returns:
        A dict with ``"unit_refs"`` (unit IDs in discovery order), ``"shards"``
        (whatever the sharder returned) and ``"unit_summaries"`` (unit ID ->
        first 150 characters of the unit's content).
    """
    cluster = sharder or shard_context
    shards = cluster(query_units, max_cluster_size=max_cluster_size)

    unit_refs: list[str] = []
    unit_summaries: dict[str, str] = {}
    for unit in query_units:
        uid = unit.unit_id
        unit_refs.append(uid)
        unit_summaries[uid] = unit.content[:150]

    if semantic_graph:
        # Pull in shard members that were not part of the query itself.
        for shard in shards:
            for uid in shard.unit_ids:
                if uid in unit_summaries:
                    continue
                resolved = semantic_graph.get_unit(uid)
                if not resolved:
                    continue
                unit_summaries[uid] = resolved.content[:150]
                unit_refs.append(uid)

    return {
        "unit_refs": unit_refs,
        "shards": shards,
        "unit_summaries": unit_summaries,
    }
|
||||
|
||||
|
||||
def build_compact_prompt(
    units: list[AtomicSemanticUnit],
    max_chars: int = 4000,
) -> str:
    """
    Render units as ``[id] content`` lines within a character budget.

    Units are considered in order; a unit whose line would push the running
    total past ``max_chars`` is not inlined and is listed by ID instead (later,
    smaller units may still be inlined). At most the first 20 deferred IDs are
    listed in the trailing ``[References: ...]`` block, which is not counted
    against the budget. Returns ``"[No units]"`` when nothing was rendered.
    """
    inlined: list[str] = []
    deferred: list[str] = []
    used = 0

    for unit in units:
        rendered = f"[{unit.unit_id}] {unit.content}\n"
        projected = used + len(rendered)
        if projected > max_chars:
            deferred.append(unit.unit_id)
            continue
        inlined.append(rendered)
        used = projected

    if deferred:
        inlined.append(f"\n[References: {', '.join(deferred[:20])}]")

    if not inlined:
        return "[No units]"
    return "".join(inlined)
|
||||
43
fusionagi/reasoning/cot.py
Normal file
43
fusionagi/reasoning/cot.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""Chain-of-thought: prompt structure and trace storage."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fusionagi.adapters.base import LLMAdapter
|
||||
|
||||
COT_SYSTEM = """You reason step by step. For each step, state your thought clearly.
|
||||
Output your final conclusion or recommendation after your reasoning."""
|
||||
|
||||
|
||||
def build_cot_messages(
|
||||
query: str,
|
||||
context: str | None = None,
|
||||
trace_so_far: list[str] | None = None,
|
||||
) -> list[dict[str, str]]:
|
||||
"""Build message list for chain-of-thought: system + optional context + user query."""
|
||||
messages: list[dict[str, str]] = [{"role": "system", "content": COT_SYSTEM}]
|
||||
if context:
|
||||
messages.append({"role": "user", "content": f"Context:\n{context}\n\nQuery: {query}"})
|
||||
else:
|
||||
messages.append({"role": "user", "content": query})
|
||||
if trace_so_far:
|
||||
assistant_content = "\n".join(trace_so_far)
|
||||
messages.append({"role": "assistant", "content": assistant_content})
|
||||
messages.append({"role": "user", "content": "Continue."})
|
||||
return messages
|
||||
|
||||
|
||||
def run_chain_of_thought(
    adapter: LLMAdapter,
    query: str,
    context: str | None = None,
    trace_so_far: list[str] | None = None,
    **kwargs: Any,
) -> tuple[str, list[str]]:
    """
    Execute a single chain-of-thought step through ``adapter``.

    Extra keyword arguments are forwarded to ``adapter.complete``.

    Returns:
        ``(response, trace)`` where ``trace`` is a new list holding the prior
        trace entries followed by this step's response; pass it back as
        ``trace_so_far`` to continue a multi-step chain.
    """
    prompt_messages = build_cot_messages(
        query,
        context=context,
        trace_so_far=trace_so_far,
    )
    response = adapter.complete(prompt_messages, **kwargs)
    prior_entries = trace_so_far or []
    return response, prior_entries + [response]
|
||||
207
fusionagi/reasoning/decomposition.py
Normal file
207
fusionagi/reasoning/decomposition.py
Normal file
@@ -0,0 +1,207 @@
|
||||
"""Recursive semantic decomposition: split text into atomic units."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from fusionagi.reasoning.native import analyze_prompt
|
||||
from fusionagi.schemas.atomic import (
|
||||
AtomicSemanticUnit,
|
||||
AtomicUnitType,
|
||||
DecompositionResult,
|
||||
RelationType,
|
||||
SemanticRelation,
|
||||
)
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
def _make_unit_id(prefix: str = "asu") -> str:
|
||||
"""Generate unique unit ID."""
|
||||
return f"{prefix}_{uuid.uuid4().hex[:12]}"
|
||||
|
||||
|
||||
def _is_atomic(text: str, min_words: int = 3) -> bool:
|
||||
"""Check if text is irreducible (atomic)."""
|
||||
content = " ".join(text.split()).strip()
|
||||
if not content or len(content) < 10:
|
||||
return True
|
||||
words = len(content.split())
|
||||
return words <= min_words
|
||||
|
||||
|
||||
def _extract_questions(text: str) -> list[str]:
|
||||
"""Extract explicit questions from text."""
|
||||
questions: list[str] = []
|
||||
content = " ".join(text.split()).strip()
|
||||
q_parts = re.split(r"\?+", content)
|
||||
for part in q_parts[:-1]:
|
||||
q = part.strip()
|
||||
if len(q) > 10:
|
||||
questions.append(q + "?")
|
||||
if not questions and any(w in content.lower() for w in ["how", "what", "why", "when", "where", "who"]):
|
||||
questions.append(content)
|
||||
return questions[:5]
|
||||
|
||||
|
||||
def _extract_constraints(text: str) -> list[str]:
|
||||
"""Extract constraint signals from text."""
|
||||
constraints: list[str] = []
|
||||
patterns = [
|
||||
r"must\s+(\w[\w\s]+?)(?:\.|$)",
|
||||
r"should\s+(\w[\w\s]+?)(?:\.|$)",
|
||||
r"cannot\s+(\w[\w\s]+?)(?:\.|$)",
|
||||
r"require[sd]?\s+(\w[\w\s]+?)(?:\.|$)",
|
||||
r"constraint[s]?:\s*(\w[\w\s]+?)(?:\.|$)",
|
||||
r"assume[sd]?\s+(\w[\w\s]+?)(?:\.|$)",
|
||||
]
|
||||
for pat in patterns:
|
||||
for m in re.finditer(pat, text, re.I):
|
||||
constraints.append(m.group(1).strip())
|
||||
return list(dict.fromkeys(constraints))[:10]
|
||||
|
||||
|
||||
def _extract_entities(text: str) -> list[str]:
|
||||
"""Extract entity-like phrases."""
|
||||
entities = re.findall(r'"([^"]+)"', text)
|
||||
entities += re.findall(r"\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b", text)
|
||||
return list(dict.fromkeys(e for e in entities if len(e) > 2))[:10]
|
||||
|
||||
|
||||
def decompose_recursive(
    text: str,
    max_depth: int = 3,
    parent_id: str | None = None,
    current_depth: int = 0,
    source_ref: str | None = None,
) -> DecompositionResult:
    """
    Break text down into atomic semantic units, recursing sentence by sentence.

    For the whitespace-normalized text this creates one root unit (always, even
    when the text is already atomic), then child units for every extracted
    question, constraint and entity, each linked to the root with a LOGICAL
    relation. When the text is not atomic and the depth budget allows, it is
    split into sentences and each sentence longer than 20 characters is
    decomposed recursively with the root as its parent.

    Args:
        text: Input text to decompose.
        max_depth: Maximum recursion depth.
        parent_id: ID of the parent unit, if this segment is part of a larger one.
        current_depth: Depth of this call in the recursion.
        source_ref: Optional source reference copied onto every unit.

    Returns:
        DecompositionResult holding the units (unique by ``unit_id``), the
        relations, and ``current_depth``. Empty input yields an empty result.
    """
    normalized = " ".join(text.split()).strip()
    if not normalized:
        return DecompositionResult(units=[], relations=[], depth=current_depth)

    analysis = analyze_prompt(normalized)

    # Root unit: a preview of the segment, typed by the detected intent.
    root_id = _make_unit_id()
    preview = normalized[:500] + ("..." if len(normalized) > 500 else "")
    root_type = (
        AtomicUnitType.INTENT if analysis.intent == "question" else AtomicUnitType.FACT
    )
    units: list[AtomicSemanticUnit] = [
        AtomicSemanticUnit(
            unit_id=root_id,
            content=preview,
            type=root_type,
            confidence=0.8,
            parent_id=parent_id,
            source_ref=source_ref,
            metadata={"intent": analysis.intent},
        )
    ]
    relations: list[SemanticRelation] = []
    if parent_id:
        relations.append(
            SemanticRelation(
                from_id=parent_id,
                to_id=root_id,
                relation_type=RelationType.LOGICAL,
            )
        )

    # (extractor, unit type, confidence) for each kind of child unit, in the
    # order the children are emitted: questions, constraints, entities.
    extraction_plan = (
        (_extract_questions, AtomicUnitType.QUESTION, 0.9),
        (_extract_constraints, AtomicUnitType.CONSTRAINT, 0.85),
        (_extract_entities, AtomicUnitType.FACT, 0.9),
    )
    for extract, unit_type, confidence in extraction_plan:
        for fragment in extract(normalized):
            child_id = _make_unit_id()
            units.append(
                AtomicSemanticUnit(
                    unit_id=child_id,
                    content=fragment,
                    type=unit_type,
                    confidence=confidence,
                    parent_id=root_id,
                    source_ref=source_ref,
                )
            )
            relations.append(
                SemanticRelation(
                    from_id=root_id,
                    to_id=child_id,
                    relation_type=RelationType.LOGICAL,
                )
            )

    # Recurse into sentences when the segment is still divisible.
    if not _is_atomic(normalized, min_words=8) and current_depth < max_depth:
        sentences = re.split(r"[.!?]\s+", normalized)
        if len(sentences) > 1:
            for raw_sentence in sentences:
                sentence = raw_sentence.strip()
                if len(sentence) <= 20:
                    continue
                nested = decompose_recursive(
                    sentence,
                    max_depth=max_depth,
                    parent_id=root_id,
                    current_depth=current_depth + 1,
                    source_ref=source_ref,
                )
                units.extend(nested.units)
                relations.extend(nested.relations)

    # Keep the first unit seen for each unit_id, preserving order.
    first_by_id: dict[str, AtomicSemanticUnit] = {}
    for unit in units:
        first_by_id.setdefault(unit.unit_id, unit)
    unique_units = list(first_by_id.values())

    logger.debug(
        "Decomposition complete",
        extra={
            "depth": current_depth,
            "units": len(unique_units),
            "relations": len(relations),
        },
    )

    return DecompositionResult(
        units=unique_units,
        relations=relations,
        depth=current_depth,
    )
|
||||
85
fusionagi/reasoning/meta_reasoning.py
Normal file
85
fusionagi/reasoning/meta_reasoning.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""Meta-reasoning: challenge assumptions, detect contradictions, revisit nodes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fusionagi.schemas.atomic import AtomicSemanticUnit, AtomicUnitType
|
||||
from fusionagi.reasoning.tot import ThoughtNode, expand_node
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
def challenge_assumptions(
    units: list[AtomicSemanticUnit],
    current_conclusion: str,
) -> list[str]:
    """
    Flag unit contents that look like assumptions behind a conclusion.

    A unit is flagged when it is typed ASSUMPTION, or when its content mentions
    "assume"/"assumption" (case-insensitive). A CONSTRAINT unit is flagged, with
    a "Constraint may be assumed: " prefix and its content cut to 100
    characters, when the conclusion contains "must", "should" or "require".

    Returns:
        The flagged strings, in unit order.
    """
    conclusion = current_conclusion.lower()
    conclusion_is_normative = any(
        marker in conclusion for marker in ("must", "should", "require")
    )

    flagged: list[str] = []
    for unit in units:
        if unit.type == AtomicUnitType.ASSUMPTION:
            flagged.append(unit.content)
            continue

        lowered = unit.content.lower()
        if "assume" in lowered or "assumption" in lowered:
            flagged.append(unit.content)
            continue

        if unit.type == AtomicUnitType.CONSTRAINT and conclusion_is_normative:
            flagged.append(f"Constraint may be assumed: {unit.content[:100]}")

    logger.debug("Assumptions flagged", extra={"count": len(flagged)})
    return flagged
|
||||
|
||||
|
||||
def detect_contradictions(
    units: list[AtomicSemanticUnit],
) -> list[tuple[str, str]]:
    """
    Heuristically pair up units that may contradict each other.

    Two units are paired when exactly one of them contains a negation word and
    the share of the earlier unit's (lower-cased, whitespace-split) words that
    also occur in the later unit exceeds 0.2. The overlap is measured relative
    to the earlier unit only, so the test is order-dependent.

    Returns:
        ``(earlier_unit_id, later_unit_id)`` tuples in unit order.
    """
    negations = {"not", "no", "never", "none", "cannot", "shouldn't", "won't", "don't", "doesn't"}

    # Tokenize every unit once: (unit, word set, contains-negation flag).
    profiles = []
    for unit in units:
        vocabulary = set(unit.content.lower().split())
        profiles.append((unit, vocabulary, bool(vocabulary & negations)))

    pairs: list[tuple[str, str]] = []
    for position, (earlier, earlier_words, earlier_negated) in enumerate(profiles):
        denominator = max(len(earlier_words), 1)
        for later, later_words, later_negated in profiles[position + 1:]:
            if earlier_negated == later_negated:
                continue
            if len(earlier_words & later_words) / denominator > 0.2:
                pairs.append((earlier.unit_id, later.unit_id))

    logger.debug("Contradictions detected", extra={"count": len(pairs)})
    return pairs
|
||||
|
||||
|
||||
def revisit_node(
    tree: ThoughtNode | None,
    node_id: str,
    new_evidence: str,
) -> ThoughtNode | None:
    """
    Re-expand the node identified by ``node_id`` with new evidence.

    The tree is searched depth-first in pre-order; the first node whose
    ``node_id`` matches is passed to ``expand_node`` together with
    ``new_evidence``, and the first 200 characters of the evidence are recorded
    on the resulting child under ``metadata["revisit_evidence"]``.
    NOTE(review): whether the child ends up in the node's ``children`` depends
    on ``expand_node`` - confirm there.

    Returns:
        ``None`` when ``tree`` is ``None``; otherwise the same ``tree`` object,
        whether or not a matching node was found.
    """
    if tree is None:
        return None

    # Explicit stack; children are pushed reversed so they pop in original order.
    target = None
    pending = [tree]
    while pending:
        candidate = pending.pop()
        if candidate.node_id == node_id:
            target = candidate
            break
        pending.extend(reversed(candidate.children))

    if target is not None:
        child = expand_node(target, new_evidence)
        child.metadata["revisit_evidence"] = new_evidence[:200]
    return tree
|
||||
50
fusionagi/reasoning/multi_path.py
Normal file
50
fusionagi/reasoning/multi_path.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""Multi-path inference: parallel hypothesis generation and scoring."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import Any, Callable
|
||||
|
||||
from fusionagi.schemas.atomic import AtomicSemanticUnit
|
||||
from fusionagi.reasoning.tot import ThoughtNode
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
def _score_coherence(node: ThoughtNode, _units: list[AtomicSemanticUnit]) -> float:
|
||||
return node.score * (0.9 + 0.1 * min(1, len(node.trace) / 5))
|
||||
|
||||
|
||||
def _score_consistency(node: ThoughtNode, units: list[AtomicSemanticUnit]) -> float:
|
||||
if not units:
|
||||
return 0.5
|
||||
unit_content = " ".join(u.content.lower() for u in units)
|
||||
thought_words = set(node.thought.lower().split())
|
||||
unit_words = set(unit_content.split())
|
||||
overlap = len(thought_words & unit_words) / max(len(thought_words), 1)
|
||||
return min(1.0, overlap * 2)
|
||||
|
||||
|
||||
def generate_and_score_parallel(
    hypotheses: list[str],
    units: list[AtomicSemanticUnit],
    score_fn: Callable[[ThoughtNode, list[AtomicSemanticUnit]], float] | None = None,
) -> list[tuple[ThoughtNode, float]]:
    """
    Score multiple hypotheses in parallel.

    Each hypothesis becomes a ThoughtNode (its trace holds just the hypothesis,
    its ``unit_refs`` the IDs of the first ten units) that is scored on a
    thread pool of at most eight workers. The score is also stored on the node.

    Args:
        hypotheses: Candidate thoughts to score.
        units: Units the hypotheses are scored against.
        score_fn: Optional scorer; defaults to the mean of the coherence and
            consistency scores.

    Returns:
        ``(node, score)`` tuples in the order of ``hypotheses``. A hypothesis
        whose scoring raised is logged at warning level and left out. An empty
        ``hypotheses`` list yields an empty result.
    """
    if not hypotheses:
        # ThreadPoolExecutor raises ValueError for max_workers=0.
        return []

    # NOTE(review): the default coherence term multiplies node.score, which is
    # still at ThoughtNode's default while scoring - confirm this is intended.
    scorer = score_fn or (
        lambda n, u: _score_coherence(n, u) * 0.5 + _score_consistency(n, u) * 0.5
    )

    def score_one(hypothesis: str) -> tuple[ThoughtNode, float]:
        node = ThoughtNode(
            thought=hypothesis,
            trace=[hypothesis],
            unit_refs=[u.unit_id for u in units[:10]],
        )
        s = scorer(node, units)
        node.score = s
        return node, s

    scored: dict[int, tuple[ThoughtNode, float]] = {}
    with ThreadPoolExecutor(max_workers=min(len(hypotheses), 8)) as ex:
        futures = {ex.submit(score_one, h): i for i, h in enumerate(hypotheses)}
        for future in as_completed(futures):
            try:
                scored[futures[future]] = future.result()
            except Exception as e:
                # Best-effort: one failing hypothesis must not sink the batch.
                logger.warning("Multi-path score failed", extra={"error": str(e)})

    # Emit in submission order so the output does not depend on thread timing.
    return [scored[i] for i in sorted(scored)]
|
||||
319
fusionagi/reasoning/native.py
Normal file
319
fusionagi/reasoning/native.py
Normal file
@@ -0,0 +1,319 @@
|
||||
"""
|
||||
Native reasoning engine: symbolic, rule-based analysis independent of external LLMs.
|
||||
|
||||
Produces structured HeadOutput from prompt analysis using:
|
||||
- Keyword and pattern extraction
|
||||
- Head-specific domain logic
|
||||
- Persona-driven synthesis
|
||||
- No external API calls
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from fusionagi.config.head_personas import get_persona
|
||||
from fusionagi.schemas.grounding import Citation
|
||||
from fusionagi.schemas.head import HeadClaim, HeadId, HeadOutput, HeadRisk
|
||||
|
||||
|
||||
@dataclass
class PromptAnalysis:
    """Result container filled in by native (rule-based) prompt analysis."""

    # Coarse intent label for the prompt.
    intent: str = ""
    # Entity-like phrases found in the prompt.
    entities: list[str] = field(default_factory=list)
    # Phrases captured after constraint trigger words.
    constraints: list[str] = field(default_factory=list)
    # Explicit questions found in the prompt.
    questions: list[str] = field(default_factory=list)
    # Relevance score per head, keyed by the head's string value.
    domain_signals: dict[str, float] = field(default_factory=dict)
    # Distinct word tokens of the prompt.
    keywords: set[str] = field(default_factory=set)
|
||||
|
||||
|
||||
# Domain keywords per head. analyze_prompt scores each head by the share of its
# keyword set that appears among the prompt's word tokens, so adding keywords to
# a head also raises the number of matches needed for the same score.
# NOTE(review): analyze_prompt tokenizes with [a-z0-9]{2,}, so the hyphenated
# "long-term" entry can never equal a token and only dilutes STRATEGY's score.
HEAD_DOMAIN_KEYWORDS: dict[HeadId, set[str]] = {
    HeadId.LOGIC: {"logic", "contradiction", "correct", "valid", "proof", "assumption", "therefore", "implies"},
    HeadId.RESEARCH: {"source", "cite", "reference", "evidence", "study", "paper", "find", "search"},
    HeadId.SYSTEMS: {"architecture", "scalability", "dependency", "system", "design", "component", "service"},
    HeadId.STRATEGY: {"strategy", "roadmap", "priority", "tradeoff", "plan", "goal", "long-term"},
    HeadId.PRODUCT: {"user", "ux", "product", "flow", "design", "experience", "interface"},
    HeadId.SECURITY: {"security", "auth", "threat", "vulnerability", "secret", "encrypt", "attack"},
    HeadId.SAFETY: {"safety", "harm", "policy", "ethical", "comply", "risk", "prevent"},
    HeadId.RELIABILITY: {"reliability", "slo", "failover", "observability", "test", "load", "uptime"},
    HeadId.COST: {"cost", "budget", "performance", "cache", "token", "efficient", "expensive"},
    HeadId.DATA: {"data", "schema", "privacy", "retention", "memory", "storage", "database"},
    HeadId.DEVEX: {"dev", "ci", "cd", "test", "tooling", "local", "developer", "workflow"},
}
|
||||
|
||||
|
||||
def _extract_content(text: str) -> str:
|
||||
"""Normalize and extract analyzable content from prompt."""
|
||||
if not text:
|
||||
return ""
|
||||
# Collapse whitespace, strip
|
||||
return " ".join(text.split()).strip()
|
||||
|
||||
|
||||
# Interrogatives matched against whole word tokens. A substring test would also
# fire on words such as "show" (contains "how") or "somewhat" (contains "what").
_INTERROGATIVES = frozenset({"how", "what", "why", "when", "where", "who", "whom", "whose"})


def analyze_prompt(prompt: str) -> PromptAnalysis:
    """
    Analyze a prompt using pattern matching and keyword extraction.

    No external APIs are called. The analysis fills in:
    - ``keywords``: distinct lower-case alphanumeric tokens of two or more characters;
    - ``intent``: "question" when the prompt contains "?" or an interrogative
      word, otherwise "statement" (more than five words) or "request";
    - ``questions``: fragments longer than 10 characters that precede a "?";
    - ``entities``: up to ten distinct quoted spans and Title Case word runs;
    - ``constraints``: phrases following must/should/cannot/require/constraint:;
    - ``domain_signals``: per-head keyword overlap, scaled by 3 and capped at 1.0.

    A falsy prompt (``""`` or ``None``) is analyzed as the empty string.
    """
    # The normalizer below already tolerates falsy input; do the same here so
    # the "?" test and the word count do not fail on None.
    prompt = prompt or ""
    content = _extract_content(prompt).lower()
    words = set(re.findall(r"\b[a-z0-9]{2,}\b", content))
    analysis = PromptAnalysis(keywords=words)

    # Intent: question vs statement vs request
    if "?" in prompt:
        analysis.intent = "question"
        # The last split element follows the final "?", so it is not a question.
        for part in re.split(r"\?+", prompt)[:-1]:
            q = part.strip()
            if len(q) > 10:
                analysis.questions.append(q + "?")
    elif words & _INTERROGATIVES:
        analysis.intent = "question"
    else:
        analysis.intent = "statement" if len(prompt.split()) > 5 else "request"

    # Entity-like phrases (title case or quoted)
    entities = re.findall(r'"([^"]+)"', prompt)
    entities += re.findall(r"\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b", prompt)
    analysis.entities = list(dict.fromkeys(e for e in entities if len(e) > 2))[:10]

    # Constraint signals
    constraint_patterns = [
        r"must\s+(\w[\w\s]+?)(?:\.|$)",
        r"should\s+(\w[\w\s]+?)(?:\.|$)",
        r"cannot\s+(\w[\w\s]+?)(?:\.|$)",
        r"require[sd]?\s+(\w[\w\s]+?)(?:\.|$)",
        r"constraint[s]?:\s*(\w[\w\s]+?)(?:\.|$)",
    ]
    for pat in constraint_patterns:
        for m in re.finditer(pat, prompt, re.I):
            analysis.constraints.append(m.group(1).strip())

    # Domain relevance per head
    for hid, keywords in HEAD_DOMAIN_KEYWORDS.items():
        if hid == HeadId.WITNESS:
            continue
        overlap = len(words & keywords) / max(len(keywords), 1)
        analysis.domain_signals[hid.value] = min(1.0, overlap * 3)

    return analysis
|
||||
|
||||
|
||||
def _derive_claims_for_head(
    head_id: HeadId,
    analysis: PromptAnalysis,
    prompt: str,
    semantic_facts: list[dict[str, Any]] | None = None,
) -> list[HeadClaim]:
    """
    Build the list of claims a head makes about a prompt.

    The list starts with a summary claim (intent plus up to three entities,
    backed by a citation of the first 200 prompt characters), followed by at
    most one head-specific claim, followed by one claim per semantic fact with
    a non-empty ``"statement"`` (first three facts only).
    """
    # The persona is looked up but not used below; the call is kept because
    # its effects are not visible from here.
    get_persona(head_id)
    relevance = analysis.domain_signals.get(head_id.value, 0.3)

    # Summary claim shared by every head.
    headline = f"The prompt addresses: {analysis.intent}"
    if analysis.entities:
        headline += f" involving {', '.join(analysis.entities[:3])}"
    claims: list[HeadClaim] = [
        HeadClaim(
            claim_text=headline,
            confidence=0.7 + relevance * 0.2,
            evidence=[Citation(source_id="prompt_analysis", excerpt=prompt[:200], confidence=1.0)],
            assumptions=[],
        )
    ]

    # Head-specific claim, if this head has one.
    domain_claim = None
    if head_id == HeadId.LOGIC:
        domain_claim = HeadClaim(
            claim_text="Logical consistency should be verified for any derived conclusions.",
            confidence=0.8,
            evidence=[],
            assumptions=["Formal reasoning applies"],
        )
    elif head_id == HeadId.SECURITY:
        if analysis.domain_signals.get(HeadId.SECURITY.value, 0) > 0.2:
            domain_claim = HeadClaim(
                claim_text="Security implications should be explicitly evaluated.",
                confidence=0.85,
                evidence=[],
                assumptions=[],
            )
    elif head_id == HeadId.SAFETY:
        domain_claim = HeadClaim(
            claim_text="Output must align with safety and policy constraints.",
            confidence=0.9,
            evidence=[],
            assumptions=[],
        )
    elif head_id == HeadId.STRATEGY and analysis.constraints:
        domain_claim = HeadClaim(
            claim_text=f"Constraints identified: {'; '.join(analysis.constraints[:2])}.",
            confidence=0.75,
            evidence=[],
            assumptions=[],
        )
    if domain_claim is not None:
        claims.append(domain_claim)

    # Memory-augmented claims from stored semantic facts.
    for fact in (semantic_facts or [])[:3]:
        statement = fact.get("statement", "")
        if not statement:
            continue
        citation = Citation(
            source_id=fact.get("source", "semantic_memory"),
            excerpt=statement[:100],
            confidence=0.8,
        )
        claims.append(
            HeadClaim(
                claim_text=statement[:200],
                confidence=0.6,
                evidence=[citation],
                assumptions=["Stored fact; verify currency"],
            )
        )

    return claims
|
||||
|
||||
|
||||
def _derive_risks_for_head(head_id: HeadId, analysis: PromptAnalysis) -> list[HeadRisk]:
    """
    List the risks a head raises for the analyzed prompt.

    Rules, applied in order: a low-severity risk when the head's relevance is
    below 0.2; a high-severity risk for the SECURITY head when relevance
    exceeds 0.3; a medium-severity risk for the SAFETY head unconditionally.
    A head with no domain signal is treated as having relevance 0.3.
    """
    relevance = analysis.domain_signals.get(head_id.value, 0.3)

    # (applies, description, severity)
    rules = (
        (
            relevance < 0.2,
            "Low domain relevance; analysis may be shallow for this head.",
            "low",
        ),
        (
            head_id == HeadId.SECURITY and relevance > 0.3,
            "Security-sensitive topic; require explicit threat assessment.",
            "high",
        ),
        (
            head_id == HeadId.SAFETY,
            "Safety review recommended before deployment.",
            "medium",
        ),
    )
    return [
        HeadRisk(description=description, severity=severity)
        for applies, description, severity in rules
        if applies
    ]
|
||||
|
||||
|
||||
def _synthesize_summary(head_id: HeadId, analysis: PromptAnalysis, claims: list[HeadClaim]) -> str:
    """
    Compose a short summary sentence sequence for a head.

    Uses the persona's ``"expression"`` (default "neutral") to introduce the
    first claim (cut to 120 characters), then notes how many explicit questions
    and constraints were found. When none of those apply, falls back to a
    generic sentence built from the head name and the persona's ``"tone"``
    (default "balanced").
    """
    persona = get_persona(head_id)
    tone = persona.get("tone", "balanced")
    expression = persona.get("expression", "neutral")
    head_name = head_id.value.replace("_", " ").title()

    sentences: list[str] = []
    if claims:
        primary = claims[0].claim_text
        ellipsis = "..." if len(primary) > 120 else ""
        sentences.append(f"From {expression} perspective: {primary[:120]}{ellipsis}.")
    if analysis.questions:
        sentences.append(f"Addresses {len(analysis.questions)} explicit question(s).")
    if analysis.constraints:
        sentences.append(f"Constraints noted: {len(analysis.constraints)}.")

    if sentences:
        return " ".join(sentences)
    return f"{head_name} head analysis: prompt analyzed with {tone} assessment."
|
||||
|
||||
|
||||
def produce_head_output(
    head_id: HeadId,
    prompt: str,
    semantic_facts: list[dict[str, Any]] | None = None,
) -> HeadOutput:
    """
    Build a HeadOutput for ``head_id`` from rule-based analysis of ``prompt``.

    No external LLM is called: the prompt is analyzed, then claims, risks and
    a summary are derived for the head. The output carries at most three
    questions and at most five recommended actions (a default action is used
    when none apply), plus the persona's tone as guidance.

    Raises:
        ValueError: If ``head_id`` is the WITNESS head.
    """
    if head_id == HeadId.WITNESS:
        raise ValueError("Witness does not produce HeadOutput; use WitnessAgent")

    analysis = analyze_prompt(prompt)
    claims = _derive_claims_for_head(head_id, analysis, prompt, semantic_facts)
    risks = _derive_risks_for_head(head_id, analysis)
    summary = _synthesize_summary(head_id, analysis, claims)

    # (applies, recommended action)
    suggestions = (
        (analysis.questions, "Address each explicit question in the response."),
        (analysis.constraints, "Verify output satisfies stated constraints."),
        (
            head_id in (HeadId.SECURITY, HeadId.SAFETY),
            "Perform domain-specific review before finalizing.",
        ),
    )
    actions = [action for applies, action in suggestions if applies]

    return HeadOutput(
        head_id=head_id,
        summary=summary,
        claims=claims,
        risks=risks,
        questions=analysis.questions[:3],
        recommended_actions=actions[:5] or ["Proceed with synthesis."],
        tone_guidance=get_persona(head_id).get("tone", "balanced"),
    )
|
||||
|
||||
|
||||
def _domain_for_head(head_id: HeadId) -> str:
|
||||
"""Map head to semantic memory domain."""
|
||||
return head_id.value
|
||||
|
||||
|
||||
class NativeReasoningProvider:
    """
    Produces HeadOutput through native reasoning, without external APIs.

    When a semantic memory is supplied, facts for the head's domain are
    queried from it and passed along to ground claims. The episodic memory is
    stored but not read by any method of this class.
    """

    def __init__(
        self,
        semantic_memory: "SemanticMemory | None" = None,
        episodic_memory: "EpisodicMemory | None" = None,
    ) -> None:
        """Remember the optional memory backends."""
        self._semantic, self._episodic = semantic_memory, episodic_memory

    def produce_head_output(self, head_id: HeadId, prompt: str) -> HeadOutput:
        """Produce HeadOutput for ``head_id`` and ``prompt``, with memory facts when available."""
        facts = self._get_relevant_facts(head_id) if self._semantic else None
        # Resolves to the module-level function, not this method.
        return produce_head_output(head_id, prompt, semantic_facts=facts)

    def _get_relevant_facts(self, head_id: HeadId, limit: int = 5) -> list[dict[str, Any]]:
        """Query semantic memory for the head's domain; return ``[]`` without a memory."""
        if self._semantic:
            return self._semantic.query(domain=_domain_for_head(head_id), limit=limit)
        return []
|
||||
49
fusionagi/reasoning/recomposition.py
Normal file
49
fusionagi/reasoning/recomposition.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""Dynamic recomposition: build higher-order insights with traceability."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from fusionagi.schemas.atomic import AtomicSemanticUnit
|
||||
from fusionagi.reasoning.tot import ThoughtNode
|
||||
|
||||
|
||||
@dataclass
class RecomposedResponse:
    """Recomposed response with traceability to atomic units."""

    # Short text built from the leading thoughts.
    summary: str = ""
    # Truncated thoughts presented as individual claims.
    key_claims: list[str] = field(default_factory=list)
    # IDs of the units referenced by the contributing nodes.
    unit_refs: list[str] = field(default_factory=list)
    # Aggregate confidence for the response.
    confidence: float = 0.0
    # Free-form bookkeeping data.
    metadata: dict[str, Any] = field(default_factory=dict)


def recompose(
    thought_nodes: list[ThoughtNode],
    atomic_units: list[AtomicSemanticUnit],
) -> RecomposedResponse:
    """
    Build higher-order insights from selected thought nodes.

    Args:
        thought_nodes: Nodes whose thoughts and unit references are combined.
        atomic_units: Accepted for interface compatibility; not read here.

    Returns:
        A RecomposedResponse whose summary joins the first three non-empty
        thoughts (200 characters each, with " [truncated]" appended when there
        are more), whose key claims are the first ten non-empty thoughts (150
        characters each), whose ``unit_refs`` are the distinct referenced unit
        IDs in first-seen order, and whose confidence is the mean node score
        capped at 1.0 (0.0 without nodes).
    """
    # A dict keeps insertion order, so the reference list is deterministic;
    # a set would order the IDs differently from run to run.
    ordered_refs: dict[str, None] = {}
    key_claims: list[str] = []
    summaries: list[str] = []

    for node in thought_nodes:
        if node.thought:
            summaries.append(node.thought[:200])
            key_claims.append(node.thought[:150])
        ordered_refs.update(dict.fromkeys(node.unit_refs))

    summary = " ".join(summaries[:3]) if summaries else "No insights."
    if len(summaries) > 3:
        summary += " [truncated]"

    avg_score = (
        sum(n.score for n in thought_nodes) / len(thought_nodes) if thought_nodes else 0.0
    )
    return RecomposedResponse(
        summary=summary,
        key_claims=key_claims[:10],
        unit_refs=list(ordered_refs),
        confidence=min(1.0, avg_score),
        metadata={"node_count": len(thought_nodes), "unit_count": len(ordered_refs)},
    )
|
||||
402
fusionagi/reasoning/tot.py
Normal file
402
fusionagi/reasoning/tot.py
Normal file
@@ -0,0 +1,402 @@
|
||||
"""Tree-of-thought: multi-branch reasoning with evaluation and selection.
|
||||
|
||||
Tree-of-Thought (ToT) extends Chain-of-Thought by exploring multiple reasoning paths
|
||||
and selecting the best one. This is useful for complex problems where the first
|
||||
reasoning path may not be optimal.
|
||||
|
||||
Key concepts:
|
||||
- Branch: A single reasoning path
|
||||
- ThoughtNode: Hierarchical node with depth, unit_refs, children (Super Big Brain)
|
||||
- Evaluation: Scoring each branch for quality
|
||||
- Selection: Choosing the best branch based on evaluation
|
||||
- Pruning: Discarding low-quality branches early
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from fusionagi.adapters.base import LLMAdapter
|
||||
from fusionagi.reasoning.cot import run_chain_of_thought, build_cot_messages
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
def _new_node_id() -> str:
    """Return a fresh identifier of the form ``node_`` + 12 hex characters."""
    return f"node_{uuid.uuid4().hex[:12]}"


@dataclass
class ThoughtNode:
    """One node of a reasoning tree.

    Nodes link to their parent by id and hold their children directly, so a
    subtree can be handled on its own at any depth. All fields have defaults;
    mutable fields are created per instance.
    """

    # Unique id, generated automatically unless supplied.
    node_id: str = field(default_factory=_new_node_id)
    # ``node_id`` of the parent node; ``None`` when no parent is recorded.
    parent_id: str | None = None
    # Text of the reasoning step held by this node.
    thought: str = ""
    # Sequence of thought strings associated with this node.
    trace: list[str] = field(default_factory=list)
    # Quality score used when pruning and merging.
    score: float = 0.0
    # Direct child nodes.
    children: list["ThoughtNode"] = field(default_factory=list)
    # Depth level of the node within its tree.
    depth: int = 0
    # Identifiers of the units this node refers to.
    unit_refs: list[str] = field(default_factory=list)
    # Free-form extra information.
    metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
def expand_node(
    node: ThoughtNode,
    new_thought: str,
    unit_refs: list[str] | None = None,
) -> ThoughtNode:
    """Create a child node under ``node`` and attach it.

    Args:
        node: Parent node. The new child is appended to ``node.children``.
        new_thought: Thought text for the child; also appended to the child's
            copy of the parent's trace.
        unit_refs: Unit ids for the child. When ``None`` or empty, the child
            inherits the parent's ids.

    Returns:
        The newly created child, one level deeper than ``node``.
    """
    child = ThoughtNode(
        parent_id=node.node_id,
        thought=new_thought,
        trace=node.trace + [new_thought],
        depth=node.depth + 1,
        # Always copy so the child never shares a list object with the caller
        # or with its parent; later edits to either side stay local.
        unit_refs=list(unit_refs or node.unit_refs),
    )
    node.children.append(child)
    return child
|
||||
|
||||
|
||||
def prune_subtree(node: ThoughtNode, prune_threshold: float) -> ThoughtNode:
    """Drop every descendant scoring below ``prune_threshold``, in place.

    A child is kept when its score is greater than or equal to the threshold;
    the same rule is then applied recursively to each kept child. Removed
    children are detached as-is (their own subtrees are not visited).

    Returns:
        The same ``node`` object that was passed in.
    """
    kept = []
    for child in node.children:
        if child.score >= prune_threshold:
            kept.append(child)
    node.children = kept

    for child in kept:
        prune_subtree(child, prune_threshold)
    return node
|
||||
|
||||
|
||||
def merge_subtrees(nodes: list[ThoughtNode], threshold: float = 0.8) -> ThoughtNode | None:
    """Fold near-equal sibling nodes into the highest-scoring one.

    The node with the highest score is chosen as the winner (the first one on
    ties). Every other node whose score is at least ``threshold`` times the
    winner's score has the first 200 characters of its thought appended to the
    winner's thought as an ``[Alternative]`` line. Closeness is judged by this
    score ratio only; the thought texts themselves are not compared.

    The winner is modified in place.

    Returns:
        ``None`` for an empty list, the sole node for a one-element list,
        otherwise the (possibly extended) winner.
    """
    if len(nodes) <= 1:
        return nodes[0] if nodes else None

    winner = max(nodes, key=lambda n: n.score)
    cutoff = winner.score * threshold
    alternatives = [
        "\n[Alternative] " + other.thought[:200]
        for other in nodes
        if other is not winner and other.score >= cutoff
    ]
    if alternatives:
        winner.thought += "".join(alternatives)
    return winner
|
||||
|
||||
|
||||
@dataclass
class ThoughtBranch:
    """One candidate line of reasoning explored by Tree-of-Thought.

    ``branch_id``, ``thought`` and ``trace`` are required; the remaining
    fields default to an unscored, non-terminal branch with no children and
    empty metadata.
    """

    # Index identifying the branch within a run.
    branch_id: int
    # Text of the reasoning for this branch.
    thought: str
    # History of texts that led to the current thought.
    trace: list[str]
    # Evaluation score; 0.0 until the branch has been scored.
    score: float = 0.0
    # Terminal flag for the branch.
    is_terminal: bool = False
    # Nested branches hanging off this one.
    children: list["ThoughtBranch"] = field(default_factory=list)
    # Free-form extra data (the evaluator stores "evaluation_reason" here).
    metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
class ToTResult:
    """Full outcome of a Tree-of-Thought run.

    All fields are required.
    """

    # Thought text of the winning branch.
    best_response: str
    # Trace belonging to the winning branch.
    best_trace: list[str]
    # Score of the winning branch.
    best_score: float
    # Branches explored during the run.
    all_branches: list[ThoughtBranch]
    # Number of LLM calls made during the run.
    total_llm_calls: int
    # Human-readable explanation of why the winner was chosen.
    selection_reason: str
|
||||
|
||||
|
||||
# System prompts for ToT
|
||||
TOT_GENERATION_SYSTEM = """You are exploring different approaches to solve a problem.
|
||||
Generate a distinct reasoning approach. Be creative and consider different angles.
|
||||
State your thought process clearly step by step."""
|
||||
|
||||
TOT_EVALUATION_SYSTEM = """You are evaluating the quality of reasoning approaches.
|
||||
Score each approach from 0 to 1 based on:
|
||||
- Logical soundness (is the reasoning valid?)
|
||||
- Completeness (does it address all aspects?)
|
||||
- Clarity (is it easy to follow?)
|
||||
- Practicality (can it be implemented?)
|
||||
Output ONLY a JSON object: {"score": 0.X, "reason": "brief explanation"}"""
|
||||
|
||||
|
||||
def _generate_branch(
    adapter: LLMAdapter,
    query: str,
    context: str | None,
    branch_num: int,
    previous_branches: list[ThoughtBranch],
    **kwargs: Any,
) -> ThoughtBranch:
    """Ask the adapter for one new reasoning approach and wrap it as a branch.

    The user message always starts with the query. When earlier branches
    exist, the first 100 characters of each are listed and the model is told
    to produce a different approach. A non-empty ``context`` is appended last.

    Args:
        adapter: Object whose ``complete(messages, **kwargs)`` produces the reply.
        query: Problem statement.
        context: Optional extra context; skipped when ``None`` or empty.
        branch_num: Id assigned to the new branch.
        previous_branches: Branches generated so far, used only for the hint.
        **kwargs: Forwarded unchanged to ``adapter.complete``.

    Returns:
        An unscored ``ThoughtBranch`` whose thought and single trace entry are
        the adapter's reply.
    """
    parts = [f"Query: {query}"]

    if previous_branches:
        listing = "\n".join(
            f"Approach {prior.branch_id}: {prior.thought[:100]}..."
            for prior in previous_branches
        )
        parts.append("\n\nPrevious approaches tried:\n" + listing)
        parts.append("\n\nGenerate a DIFFERENT approach.")

    if context:
        parts.append(f"\n\nContext: {context}")

    system_msg = {"role": "system", "content": TOT_GENERATION_SYSTEM}
    user_msg = {"role": "user", "content": "".join(parts)}
    reply = adapter.complete([system_msg, user_msg], **kwargs)

    return ThoughtBranch(branch_id=branch_num, thought=reply, trace=[reply])
|
||||
|
||||
|
||||
def _evaluate_branch(
    adapter: LLMAdapter,
    branch: ThoughtBranch,
    query: str,
    **kwargs: Any,
) -> float:
    """Ask the adapter to score a branch and parse the score from its reply.

    Parsing is attempted in this order:

    1. A JSON object spanning the first ``{`` to the last ``}`` of the reply.
       Its ``score`` (default 0.5 when the key is absent) is used, and its
       ``reason`` (default ``""``) is stored in
       ``branch.metadata["evaluation_reason"]``.
    2. The first standalone number between 0 and 1 found in the reply text
       (``0``, ``1``, ``0.85``, ``.5``, ``1.0``). Digits that are part of a
       larger number or word, such as the ``1`` in ``10``, are ignored.
    3. The neutral default ``0.5``.

    A JSON score that is missing a usable number (``null``, a list, text, NaN)
    is treated as unparseable and falls through to step 2.

    Args:
        adapter: Object whose ``complete(messages, **kwargs)`` produces the reply.
        branch: Branch to score; only its ``thought`` is sent, and its
            ``metadata`` may be updated as described above.
        query: The problem the branch is trying to solve.
        **kwargs: Forwarded unchanged to ``adapter.complete``.

    Returns:
        A score clamped to ``[0.0, 1.0]``.
    """
    # Function-scope imports: the module does not import these at file level.
    import math
    import re

    messages = [
        {"role": "system", "content": TOT_EVALUATION_SYSTEM},
        {
            "role": "user",
            "content": f"Query: {query}\n\nReasoning approach:\n{branch.thought}\n\nScore this approach.",
        },
    ]
    response = adapter.complete(messages, **kwargs)

    # 1) Preferred format: an embedded JSON object.
    json_start = response.find("{")
    json_end = response.rfind("}")
    if json_start != -1 and json_end > json_start:
        try:
            result = json.loads(response[json_start:json_end + 1])
        except json.JSONDecodeError:
            result = None
        if isinstance(result, dict):
            try:
                score = float(result.get("score", 0.5))
            except (TypeError, ValueError):
                # e.g. {"score": null} or {"score": "high"}
                score = None
            if score is not None and not math.isnan(score):
                branch.metadata["evaluation_reason"] = result.get("reason", "")
                return max(0.0, min(1.0, score))

    # 2) Fallback: first standalone number in [0, 1]. The look-arounds reject
    #    digits embedded in longer tokens ("10", "1st", "1.5", "v0").
    match = re.search(
        r"(?<![\w.])(?:0(?:\.\d+)?|\.\d+|1(?:\.0+)?)(?!\w|\.\d)",
        response,
    )
    if match:
        return max(0.0, min(1.0, float(match.group())))

    # 3) Nothing usable in the reply.
    return 0.5
|
||||
|
||||
|
||||
def _select_best_branch(branches: list[ThoughtBranch]) -> tuple[ThoughtBranch, str]:
    """Pick the highest-scoring branch and explain the choice.

    Args:
        branches: Candidate branches; must not be empty.

    Returns:
        ``(best_branch, reason)``. With a single candidate the reason is
        ``"Only one branch available"``. Otherwise the reason says whether the
        winner leads the runner-up by more than 0.2 ("clear winner") or was
        picked among close alternatives. On equal scores the earlier branch in
        ``branches`` wins.

    Raises:
        ValueError: If ``branches`` is empty.
    """
    if not branches:
        raise ValueError("No branches to select from")

    if len(branches) == 1:
        return branches[0], "Only one branch available"

    # At least two candidates from here on; sorted() is stable, so ties keep
    # their original relative order.
    ranked = sorted(branches, key=lambda b: b.score, reverse=True)
    best, runner_up = ranked[0], ranked[1]

    if best.score - runner_up.score > 0.2:
        reason = f"Clear winner with score {best.score:.2f} (next best: {runner_up.score:.2f})"
    else:
        reason = f"Selected highest score {best.score:.2f} among close alternatives"

    return best, reason
|
||||
|
||||
|
||||
def run_tree_of_thought(
    adapter: LLMAdapter,
    query: str,
    context: str | None = None,
    max_branches: int = 3,
    depth: int = 1,
    prune_threshold: float = 0.3,
    **kwargs: Any,
) -> tuple[str, list[str]]:
    """
    Explore several reasoning branches and return the best one.

    Flow: generate ``max_branches`` candidates, score each, discard those
    below ``prune_threshold``, optionally refine the survivors for
    ``depth - 1`` rounds (a refinement replaces its predecessor only when it
    scores strictly higher), then pick the top-scoring branch.

    Plain chain-of-thought is used instead when ``max_branches`` is 1 or less,
    or when every candidate is pruned.

    Args:
        adapter: LLM adapter for generation and evaluation.
        query: The question or problem to reason about.
        context: Optional context included when generating the initial branches.
        max_branches: Number of reasoning branches to generate (minimum 1).
        depth: 1 for a single pass; each extra level adds one refinement round.
        prune_threshold: Minimum score a branch needs to survive pruning.
        **kwargs: Additional arguments passed to adapter.complete().

    Returns:
        Tuple of (best_response, trace_list).
    """
    max_branches = max(max_branches, 1)
    if max_branches == 1:
        # Nothing to compare against: defer to chain-of-thought.
        return run_chain_of_thought(adapter, query, context=context, **kwargs)

    logger.info(
        "Starting Tree-of-Thought",
        extra={"query_length": len(query), "max_branches": max_branches, "depth": depth},
    )

    total_llm_calls = 0

    # Phase 1: generate candidates, each one shown the candidates before it.
    candidates: list[ThoughtBranch] = []
    for index in range(max_branches):
        candidates.append(_generate_branch(adapter, query, context, index, candidates, **kwargs))
        total_llm_calls += 1

    # Phase 2: score every candidate.
    for candidate in candidates:
        candidate.score = _evaluate_branch(adapter, candidate, query, **kwargs)
        total_llm_calls += 1

    # Phase 3: prune.
    survivors = [candidate for candidate in candidates if candidate.score >= prune_threshold]
    if not survivors:
        logger.warning("All ToT branches pruned, falling back to CoT")
        return run_chain_of_thought(adapter, query, context=context, **kwargs)

    # Phase 4: refinement rounds (none when depth <= 1).
    for round_no in range(1, depth):
        next_round = []
        for current in survivors:
            refinement_prompt = (
                f"Original approach:\n{current.thought}\n\n"
                f"Score: {current.score:.2f}\n"
                f"Feedback: {current.metadata.get('evaluation_reason', 'N/A')}\n\n"
                "Improve this approach based on the feedback. Make it more complete and rigorous."
            )
            messages = [
                {"role": "system", "content": TOT_GENERATION_SYSTEM},
                {"role": "user", "content": f"Query: {query}\n\n{refinement_prompt}"},
            ]
            revised_thought = adapter.complete(messages, **kwargs)
            total_llm_calls += 1

            revised = ThoughtBranch(
                branch_id=current.branch_id,
                thought=revised_thought,
                trace=current.trace + [f"[Refinement {round_no}] {revised_thought}"],
            )
            revised.score = _evaluate_branch(adapter, revised, query, **kwargs)
            total_llm_calls += 1

            # The refinement must beat the current version to replace it.
            next_round.append(revised if revised.score > current.score else current)
        survivors = next_round

    # Phase 5: choose the winner and report.
    best_branch, selection_reason = _select_best_branch(survivors)

    logger.info(
        "Tree-of-Thought completed",
        extra={
            "best_score": best_branch.score,
            "total_branches": len(survivors),
            "total_llm_calls": total_llm_calls,
        },
    )

    trace = [
        f"[ToT Branch {best_branch.branch_id}] Score: {best_branch.score:.2f}",
        best_branch.thought,
    ]
    evaluation_reason = best_branch.metadata.get("evaluation_reason")
    if evaluation_reason:
        trace.append(f"[Evaluation] {evaluation_reason}")
    trace.append(f"[Selection] {selection_reason}")

    return best_branch.thought, trace
|
||||
|
||||
|
||||
def run_tree_of_thought_detailed(
    adapter: LLMAdapter,
    query: str,
    context: str | None = None,
    max_branches: int = 3,
    depth: int = 1,
    prune_threshold: float = 0.3,
    **kwargs: Any,
) -> ToTResult:
    """
    Run Tree-of-Thought and return detailed results including all branches.

    Follows the same generate / evaluate / prune / refine / select flow as
    run_tree_of_thought, with two differences: the result is a ToTResult, and
    when every branch falls below ``prune_threshold`` the single best-scoring
    branch is kept instead of falling back to chain-of-thought.

    Args:
        adapter: LLM adapter for generation and evaluation.
        query: The question or problem to reason about.
        context: Optional context included when generating the initial branches.
        max_branches: Number of branches to generate (minimum 1). With 1, plain
            chain-of-thought is used and reported as a single branch scored 0.5.
        depth: 1 for a single pass; each extra level adds one refinement round
            over the surviving branches. A refinement replaces its predecessor
            only when it scores strictly higher.
        prune_threshold: Minimum score a branch needs to survive pruning.
        **kwargs: Additional arguments passed to adapter.complete().

    Returns:
        ToTResult whose ``all_branches`` holds the final state of every
        generated branch (pruned ones included, in generation order) and whose
        ``total_llm_calls`` counts generation, evaluation and refinement calls.
    """
    max_branches = max(max_branches, 1)

    if max_branches == 1:
        response, trace = run_chain_of_thought(adapter, query, context=context, **kwargs)
        single_branch = ThoughtBranch(branch_id=0, thought=response, trace=trace, score=0.5)
        return ToTResult(
            best_response=response,
            best_trace=trace,
            best_score=0.5,
            all_branches=[single_branch],
            total_llm_calls=1,
            selection_reason="Single branch (CoT mode)",
        )

    total_llm_calls = 0
    all_branches: list[ThoughtBranch] = []

    # Generate and evaluate branches; each new branch sees the earlier ones.
    for i in range(max_branches):
        branch = _generate_branch(adapter, query, context, i, all_branches, **kwargs)
        total_llm_calls += 1
        branch.score = _evaluate_branch(adapter, branch, query, **kwargs)
        total_llm_calls += 1
        all_branches.append(branch)

    # Prune by position so refinements can replace entries of all_branches.
    alive = [i for i, b in enumerate(all_branches) if b.score >= prune_threshold]
    if not alive:
        # Use best of all branches even if below threshold (first one on ties).
        alive = [max(range(len(all_branches)), key=lambda i: all_branches[i].score)]

    # Iterative refinement for depth > 1, mirroring run_tree_of_thought.
    for round_no in range(1, depth):
        for i in alive:
            current = all_branches[i]
            refinement_prompt = (
                f"Original approach:\n{current.thought}\n\n"
                f"Score: {current.score:.2f}\n"
                f"Feedback: {current.metadata.get('evaluation_reason', 'N/A')}\n\n"
                "Improve this approach based on the feedback. Make it more complete and rigorous."
            )
            messages = [
                {"role": "system", "content": TOT_GENERATION_SYSTEM},
                {"role": "user", "content": f"Query: {query}\n\n{refinement_prompt}"},
            ]
            refined_thought = adapter.complete(messages, **kwargs)
            total_llm_calls += 1

            refined = ThoughtBranch(
                branch_id=current.branch_id,
                thought=refined_thought,
                trace=current.trace + [f"[Refinement {round_no}] {refined_thought}"],
            )
            refined.score = _evaluate_branch(adapter, refined, query, **kwargs)
            total_llm_calls += 1

            # Keep the better version.
            if refined.score > current.score:
                all_branches[i] = refined

    # Select best among the surviving branches.
    best_branch, selection_reason = _select_best_branch([all_branches[i] for i in alive])

    return ToTResult(
        best_response=best_branch.thought,
        best_trace=best_branch.trace,
        best_score=best_branch.score,
        all_branches=all_branches,
        total_llm_calls=total_llm_calls,
        selection_reason=selection_reason,
    )
|
||||
Reference in New Issue
Block a user