Initial commit: add .gitignore and README
This commit is contained in:
147
fusionagi/api/routes/sessions.py
Normal file
147
fusionagi/api/routes/sessions.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""Session and prompt routes."""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect
|
||||
|
||||
from fusionagi.api.dependencies import get_orchestrator, get_session_store, get_event_bus, get_safety_pipeline
|
||||
from fusionagi.api.websocket import handle_stream
|
||||
from fusionagi.core import run_dvadasa, select_heads_for_complexity, extract_sources_from_head_outputs
|
||||
from fusionagi.schemas.commands import parse_user_input, UserIntent
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _ensure_init():
|
||||
from fusionagi.api.dependencies import ensure_initialized
|
||||
ensure_initialized()
|
||||
|
||||
|
||||
@router.post("")
|
||||
def create_session(user_id: str | None = None) -> dict[str, Any]:
|
||||
"""Create a new session."""
|
||||
_ensure_init()
|
||||
store = get_session_store()
|
||||
if not store:
|
||||
raise HTTPException(status_code=503, detail="Session store not initialized")
|
||||
session_id = str(uuid.uuid4())
|
||||
store.create(session_id, user_id)
|
||||
return {"session_id": session_id, "user_id": user_id}
|
||||
|
||||
|
||||
@router.post("/{session_id}/prompt")
|
||||
def submit_prompt(session_id: str, body: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Submit a prompt and receive FinalResponse (sync)."""
|
||||
_ensure_init()
|
||||
store = get_session_store()
|
||||
orch = get_orchestrator()
|
||||
bus = get_event_bus()
|
||||
if not store or not orch:
|
||||
raise HTTPException(status_code=503, detail="Service not initialized")
|
||||
|
||||
sess = store.get(session_id)
|
||||
if not sess:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
prompt = body.get("prompt", "")
|
||||
parsed = parse_user_input(prompt)
|
||||
|
||||
if not prompt or not parsed.cleaned_prompt.strip():
|
||||
if parsed.intent in (UserIntent.SHOW_DISSENT, UserIntent.RERUN_RISK, UserIntent.EXPLAIN_REASONING, UserIntent.SOURCES):
|
||||
hist = sess.get("history", [])
|
||||
if hist:
|
||||
prompt = hist[-1].get("prompt", "")
|
||||
if not prompt:
|
||||
raise HTTPException(status_code=400, detail="No previous prompt; provide a prompt for this command")
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="prompt is required")
|
||||
|
||||
effective_prompt = parsed.cleaned_prompt.strip() or prompt
|
||||
pipeline = get_safety_pipeline()
|
||||
if pipeline:
|
||||
pre_result = pipeline.pre_check(effective_prompt)
|
||||
if not pre_result.allowed:
|
||||
raise HTTPException(status_code=400, detail=pre_result.reason or "Input moderation failed")
|
||||
|
||||
task_id = orch.submit_task(goal=effective_prompt[:200])
|
||||
|
||||
# Dynamic head selection
|
||||
head_ids = select_heads_for_complexity(effective_prompt)
|
||||
if parsed.intent.value == "head_strategy" and parsed.head_id:
|
||||
head_ids = [parsed.head_id]
|
||||
|
||||
force_second = parsed.intent == UserIntent.RERUN_RISK
|
||||
return_heads = parsed.intent == UserIntent.SOURCES
|
||||
|
||||
result = run_dvadasa(
|
||||
orchestrator=orch,
|
||||
task_id=task_id,
|
||||
user_prompt=effective_prompt,
|
||||
parsed=parsed,
|
||||
head_ids=head_ids if parsed.intent.value != "normal" or body.get("use_all_heads") else None,
|
||||
event_bus=bus,
|
||||
force_second_pass=force_second,
|
||||
return_head_outputs=return_heads,
|
||||
)
|
||||
|
||||
if return_heads and isinstance(result, tuple):
|
||||
final, head_outputs = result
|
||||
else:
|
||||
final = result
|
||||
head_outputs = []
|
||||
|
||||
if not final:
|
||||
raise HTTPException(status_code=500, detail="Failed to produce response")
|
||||
|
||||
if pipeline:
|
||||
post_result = pipeline.post_check(final.final_answer)
|
||||
if not post_result.passed:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Output scan failed: {', '.join(post_result.flags)}",
|
||||
)
|
||||
|
||||
entry = {
|
||||
"prompt": effective_prompt,
|
||||
"final_answer": final.final_answer,
|
||||
"confidence_score": final.confidence_score,
|
||||
"head_contributions": final.head_contributions,
|
||||
}
|
||||
store.append_history(session_id, entry)
|
||||
|
||||
response: dict[str, Any] = {
|
||||
"task_id": task_id,
|
||||
"final_answer": final.final_answer,
|
||||
"transparency_report": final.transparency_report.model_dump(),
|
||||
"head_contributions": final.head_contributions,
|
||||
"confidence_score": final.confidence_score,
|
||||
}
|
||||
if parsed.intent == UserIntent.SHOW_DISSENT:
|
||||
response["response_mode"] = "show_dissent"
|
||||
response["disputed_claims"] = final.transparency_report.agreement_map.disputed_claims
|
||||
elif parsed.intent == UserIntent.EXPLAIN_REASONING:
|
||||
response["response_mode"] = "explain"
|
||||
elif parsed.intent == UserIntent.SOURCES and head_outputs:
|
||||
response["sources"] = extract_sources_from_head_outputs(head_outputs)
|
||||
return response
|
||||
|
||||
|
||||
@router.websocket("/{session_id}/stream")
|
||||
async def stream_websocket(websocket: WebSocket, session_id: str) -> None:
|
||||
"""WebSocket for streaming Dvādaśa response. Send {\"prompt\": \"...\"} to start."""
|
||||
await websocket.accept()
|
||||
try:
|
||||
data = await websocket.receive_json()
|
||||
prompt = data.get("prompt", "")
|
||||
async def send_evt(evt: dict) -> None:
|
||||
await websocket.send_json(evt)
|
||||
await handle_stream(session_id, prompt, send_evt)
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
except Exception as e:
|
||||
try:
|
||||
await websocket.send_json({"type": "error", "message": str(e)})
|
||||
except Exception:
|
||||
pass
|
||||
Reference in New Issue
Block a user