Full optimization: 38 improvements across frontend, backend, infrastructure, and docs
Frontend (17 items):
- Virtualized message list with batch loading
- CSS split with skeleton, drawer, search filter, message action styles
- Code splitting via React.lazy + Suspense for Admin/Ethics/Settings pages
- Skeleton loading components (Skeleton, SkeletonCard, SkeletonGrid)
- Debounced search/filter component (SearchFilter)
- Error boundary with fallback UI
- Keyboard shortcuts (Ctrl+K search, Ctrl+Enter send, Escape dismiss)
- Page transition animations (fade-in)
- PWA support (manifest.json + service worker)
- WebSocket auto-reconnect with exponential backoff (10 retries)
- Chat history persistence to localStorage (500 msg limit)
- Message edit/delete on hover
- Copy-to-clipboard on code blocks
- Mobile drawer (bottom-sheet for consensus panel)
- File upload support
- User preferences sync to backend

Testing (8 items):
- Component tests: Toast, Markdown, ChatMessage, Avatar, ErrorBoundary, Skeleton
- Hook tests: useChatHistory
- E2E smoke tests (5 tests)
- Accessibility audit utility

Backend (12 items):
- Vector memory with cosine similarity search
- TTS/STT adapter factory wiring
- Geometry kernel with orphan detection
- Tenant registry with CRUD operations
- Response cache with TTL
- Connection pool (async)
- Background task queue
- Health check endpoints (/health, /ready)
- Request tracing middleware (X-Request-ID)
- API key rotation mechanism
- Environment-based config (settings.py)
- API route documentation improvements

Infrastructure (4 items):
- Grafana dashboard template
- Database migration system
- Storybook configuration

Documentation (3 items):
- ADR-001: Advisory Governance Model
- ADR-002: Twelve-Head Architecture
- ADR-003: Consequence Engine

552 Python tests + 45 frontend tests passing, 0 ruff errors.

Co-Authored-By: Nakamoto, S <defi@defi-oracle.io>
29
docs/adr/001-advisory-governance.md
Normal file
@@ -0,0 +1,29 @@
# ADR-001: Advisory Governance Model

## Status

Accepted

## Context

FusionAGI needed a governance model for its 12-headed AGI orchestrator. Traditional AI safety approaches use hard enforcement (blocking, filtering, rate limiting). The question was whether to enforce constraints rigidly or allow the system to learn from consequences.

## Decision

All governance constraints operate in **advisory mode** by default:

- Safety head reports observations rather than blocking
- File/HTTP tool restrictions log warnings but proceed
- Rate limiter logs exceedances but allows requests
- Manufacturing gate uses GovernanceMode.ADVISORY
- Ethics engine learns from consequences, not from rules

The `GovernanceMode.ENFORCING` option remains available for deployment contexts that require it.

## Consequences

- The system learns faster because it experiences the consequences of its choices
- Risk of harmful outputs is higher during the learning phase
- A full audit trail enables post-hoc analysis of every decision
- The ConsequenceEngine provides the primary feedback loop for ethical learning
- All advisory warnings are logged with trace IDs for accountability

## Alternatives Considered

1. **Hard enforcement** — Rejected: prevents learning, creates a false sense of safety
2. **Hybrid (enforce critical, advise rest)** — Partially adopted: certain hardware safety limits (e.g., embodiment force limits) still log but don't clamp
3. **No governance** — Rejected: transparency and auditability are still required
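The advisory/enforcing split above amounts to a mode check at each constraint site: both modes record the violation, but only enforcing mode blocks. A minimal sketch — `check_rate_limit` and its signature are hypothetical; only the `GovernanceMode.ADVISORY` / `ENFORCING` names come from the ADR:

```python
from enum import Enum


class GovernanceMode(Enum):
    ADVISORY = "advisory"
    ENFORCING = "enforcing"


def check_rate_limit(count: int, limit: int, mode: GovernanceMode, log: list) -> bool:
    """Return True if the request may proceed (hypothetical helper)."""
    if count <= limit:
        return True
    # Both modes record the exceedance for the audit trail;
    # only ENFORCING actually blocks the request.
    log.append(f"rate limit exceeded: {count}/{limit}")
    return mode is GovernanceMode.ADVISORY


audit_log = []
assert check_rate_limit(5, 10, GovernanceMode.ADVISORY, audit_log)       # within limit
assert check_rate_limit(15, 10, GovernanceMode.ADVISORY, audit_log)      # warned, allowed
assert not check_rate_limit(15, 10, GovernanceMode.ENFORCING, audit_log)  # blocked
assert len(audit_log) == 2  # every exceedance is logged either way
```

The key property is that the audit log is identical in both modes; the only behavioral difference is the return value.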
39
docs/adr/002-twelve-head-architecture.md
Normal file
@@ -0,0 +1,39 @@
# ADR-002: Twelve-Head (Dvādaśa) Architecture

## Status

Accepted

## Context

Multi-agent systems typically use 2-5 agents with fixed roles. FusionAGI needed a system that could analyze problems from many perspectives simultaneously while maintaining coherent output.

## Decision

The orchestrator decomposes every query across **12 specialized heads**:

| Head | Role |
|------|------|
| Logic | Logical reasoning and consistency |
| Research | Source evaluation and synthesis |
| Systems | Architecture and integration |
| Strategy | Long-term planning |
| Product | User experience and design |
| Security | Threat analysis |
| Safety | Risk observation (advisory) |
| Reliability | Fault tolerance |
| Cost | Resource optimization |
| Data | Statistical reasoning |
| DevEx | Developer experience |
| Witness | Audit and observation |

The Witness head is special: it observes but doesn't contribute to the consensus.

## Consequences

- Comprehensive analysis from 12 angles on every query
- Higher latency (12 parallel LLM calls) but better quality
- The InsightBus enables cross-head learning
- Each head has a unique color identity in the UI for visual distinction
- The consensus mechanism must handle disagreement gracefully

## Alternatives Considered

1. **3-5 heads** — Rejected: insufficient perspective diversity
2. **Dynamic head count** — Future consideration: some queries don't need all 12
3. **Hierarchical heads** — Rejected: flat structure promotes equal consideration
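The fan-out implied by the table — 12 concurrent calls, Witness excluded from consensus — can be sketched with `asyncio.gather`. The head names come from the ADR; the function names and return shape are illustrative:

```python
import asyncio

HEAD_IDS = ["logic", "research", "systems", "strategy", "product", "security",
            "safety", "reliability", "cost", "data", "devex", "witness"]


async def run_head(head_id: str, query: str) -> dict:
    # Stand-in for a per-head LLM call; each head analyzes the
    # same query from its own specialized perspective.
    return {"head_id": head_id, "summary": f"{head_id} analysis of {query!r}"}


async def orchestrate(query: str) -> list:
    # All 12 heads run concurrently, so wall-clock latency is
    # bounded by the slowest head rather than the sum of all 12.
    results = await asyncio.gather(*(run_head(h, query) for h in HEAD_IDS))
    # Witness observes (its result is still produced and auditable)
    # but is excluded from the consensus inputs.
    return [r for r in results if r["head_id"] != "witness"]


contributions = asyncio.run(orchestrate("example query"))
assert len(contributions) == 11
```

This is why the ADR notes "higher latency but better quality": the cost of 12 calls is paid in parallel, not serially.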
30
docs/adr/003-consequence-engine.md
Normal file
@@ -0,0 +1,30 @@
# ADR-003: Consequence Engine for Ethical Learning

## Status

Accepted

## Context

Traditional AI ethics systems use static rules (constitutional AI, RLHF reward models). FusionAGI needed a system that could learn ethical behavior from experience — understanding that every choice carries consequences and that risk/reward assessment improves with data.

## Decision

Implemented a **ConsequenceEngine** that:

1. Records every choice the system makes (action + alternatives considered)
2. Estimates risk and reward before acting
3. Records actual outcomes after execution
4. Computes "surprise factor" (prediction error)
5. Feeds into AdaptiveEthics for lesson generation
6. Uses an adaptive risk memory window that grows with experience

The weight system for ethical lessons is **unclamped** — extreme outcomes can push lesson weights below 0 (strong negative signal) or above 1.

## Consequences

- The system develops genuine experiential ethics rather than rule-following
- Early-stage behavior may be more exploratory (higher risk)
- All consequence records are persisted via PersistentLearningStore
- Cross-head learning via InsightBus amplifies ethical insights
- The SelfModel's values evolve based on consequence feedback
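The surprise factor (step 4) and the unclamped lesson weights can be sketched as follows. All names here are illustrative; only "surprise factor" as prediction error and the deliberately unclamped weight range come from the ADR:

```python
from dataclasses import dataclass


@dataclass
class ConsequenceRecord:
    action: str
    predicted_reward: float
    actual_reward: float = 0.0

    @property
    def surprise(self) -> float:
        # Prediction error: how far the observed outcome
        # diverged from the pre-action estimate.
        return abs(self.actual_reward - self.predicted_reward)


def update_weight(weight: float, outcome: float, lr: float = 0.5) -> float:
    # Deliberately unclamped: an extreme outcome can push a lesson
    # weight below 0 or above 1, as the ADR specifies.
    return weight + lr * (outcome - weight)


rec = ConsequenceRecord("deploy", predicted_reward=0.8, actual_reward=0.3)
assert round(rec.surprise, 2) == 0.5

assert update_weight(0.5, -2.0) == -0.75  # strong negative signal, below 0
```

A high surprise value indicates the risk/reward estimator was badly calibrated for that action, which is exactly the signal AdaptiveEthics would want amplified.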
## Alternatives Considered

1. **RLHF-style reward model** — Rejected: requires human feedback loop, doesn't scale
2. **Constitutional AI** — Rejected: static rules, doesn't learn
3. **No ethics system** — Rejected: need accountability and learning signal

12
frontend/.storybook/main.ts
Normal file
@@ -0,0 +1,12 @@
import type { StorybookConfig } from '@storybook/react-vite'

const config: StorybookConfig = {
  stories: ['../src/**/*.stories.@(ts|tsx)'],
  framework: {
    name: '@storybook/react-vite',
    options: {},
  },
  addons: ['@storybook/addon-essentials'],
}

export default config
16
frontend/.storybook/preview.ts
Normal file
@@ -0,0 +1,16 @@
import type { Preview } from '@storybook/react'
import '../src/App.css'

const preview: Preview = {
  parameters: {
    backgrounds: {
      default: 'dark',
      values: [
        { name: 'dark', value: '#0f0f14' },
        { name: 'light', value: '#f5f5f7' },
      ],
    },
  },
}

export default preview
22
frontend/public/manifest.json
Normal file
@@ -0,0 +1,22 @@
{
  "name": "FusionAGI",
  "short_name": "FusionAGI",
  "description": "12-headed AGI orchestrator with multi-perspective reasoning",
  "start_url": "/",
  "display": "standalone",
  "background_color": "#0f0f14",
  "theme_color": "#3b82f6",
  "orientation": "any",
  "icons": [
    {
      "src": "/icon-192.png",
      "sizes": "192x192",
      "type": "image/png"
    },
    {
      "src": "/icon-512.png",
      "sizes": "512x512",
      "type": "image/png"
    }
  ]
}
34
frontend/public/sw.js
Normal file
@@ -0,0 +1,34 @@
const CACHE_NAME = 'fusionagi-v1'
const STATIC_ASSETS = ['/', '/index.html']

self.addEventListener('install', (event) => {
  event.waitUntil(
    caches.open(CACHE_NAME).then((cache) => cache.addAll(STATIC_ASSETS))
  )
  self.skipWaiting()
})

self.addEventListener('activate', (event) => {
  event.waitUntil(
    caches.keys().then((keys) =>
      Promise.all(keys.filter((k) => k !== CACHE_NAME).map((k) => caches.delete(k)))
    )
  )
  self.clients.claim()
})

self.addEventListener('fetch', (event) => {
  if (event.request.method !== 'GET') return
  const url = new URL(event.request.url)
  if (url.pathname.startsWith('/v1/')) return

  event.respondWith(
    fetch(event.request)
      .then((response) => {
        const clone = response.clone()
        caches.open(CACHE_NAME).then((cache) => cache.put(event.request, clone))
        return response
      })
      .catch(() => caches.match(event.request))
  )
})
@@ -692,6 +692,128 @@ body {
  outline-offset: 2px;
}

/* ========== Skeleton Loading ========== */
.skeleton {
  background: var(--bg-tertiary);
  border-radius: 4px;
  animation: skeleton-pulse 1.5s ease-in-out infinite;
  margin-bottom: 0.4rem;
}
.skeleton-card {
  background: var(--card-bg); border: 1px solid var(--border);
  border-radius: 8px; padding: 1rem;
  display: flex; flex-direction: column; gap: 0.5rem;
}
@keyframes skeleton-pulse {
  0%, 100% { opacity: 0.4; }
  50% { opacity: 0.8; }
}

/* ========== Code Block Copy ========== */
.code-block-wrapper {
  position: relative; margin: 0.5rem 0;
}
.copy-code-btn {
  position: absolute; top: 0.4rem; right: 0.4rem;
  padding: 0.2rem 0.5rem; background: var(--bg-secondary);
  border: 1px solid var(--border); border-radius: 4px;
  color: var(--text-muted); cursor: pointer; font-size: 0.7rem;
  opacity: 0; transition: opacity 0.15s;
  z-index: 1;
}
.code-block-wrapper:hover .copy-code-btn { opacity: 1; }
.copy-code-btn:hover { color: var(--text-primary); background: var(--bg-tertiary); }

/* ========== Message Actions ========== */
.message-actions {
  display: flex; gap: 0.25rem; margin-top: 0.25rem;
}
.msg-action-btn {
  padding: 0.15rem 0.4rem; background: var(--bg-tertiary);
  border: 1px solid var(--border); border-radius: 3px;
  color: var(--text-muted); cursor: pointer; font-size: 0.7rem;
}
.msg-action-btn:hover { color: var(--text-primary); }

/* ========== Virtual Messages ========== */
.load-more-btn {
  display: block; margin: 0.5rem auto; padding: 0.4rem 1rem;
  background: var(--bg-tertiary); border: 1px solid var(--border);
  border-radius: 6px; color: var(--text-secondary); cursor: pointer;
  font-size: 0.8rem;
}
.load-more-btn:hover { background: var(--bg-secondary); }

/* ========== Clear History ========== */
.clear-history-btn {
  padding: 0.15rem 0.5rem; background: transparent;
  border: 1px solid var(--border); border-radius: 4px;
  color: var(--text-muted); cursor: pointer; font-size: 0.7rem;
}
.clear-history-btn:hover { color: var(--danger); border-color: var(--danger); }

/* ========== Mobile Drawer ========== */
.drawer-trigger {
  display: block; width: 100%; padding: 0.5rem 1rem;
  background: var(--bg-secondary); border: 1px solid var(--border);
  border-radius: 8px; color: var(--accent); cursor: pointer;
  font-size: 0.85rem; text-align: center;
  margin: 0.5rem 0; min-height: 44px;
}
.drawer-overlay {
  position: fixed; inset: 0; background: rgba(0, 0, 0, 0.5);
  z-index: 100; display: flex; align-items: flex-end;
}
.drawer-panel {
  width: 100%; max-height: 70vh; background: var(--bg-primary);
  border-radius: 16px 16px 0 0; overflow-y: auto;
  animation: drawer-slide-up 0.25s ease-out;
}
.drawer-header {
  display: flex; justify-content: space-between; align-items: center;
  padding: 1rem; border-bottom: 1px solid var(--border); position: sticky; top: 0;
  background: var(--bg-primary);
}
.drawer-body { padding: 1rem; }
.drawer-panel .consensus-panel {
  width: 100%; border-left: none; padding: 0;
}
@keyframes drawer-slide-up {
  from { transform: translateY(100%); }
  to { transform: translateY(0); }
}

/* ========== Error Boundary ========== */
.error-boundary-fallback {
  flex: 1; display: flex; flex-direction: column;
  align-items: center; justify-content: center;
  padding: 2rem; text-align: center; gap: 1rem;
}

/* ========== Page Transitions ========== */
.main > * {
  animation: page-fade-in 0.2s ease-out;
}
@keyframes page-fade-in {
  from { opacity: 0; transform: translateY(4px); }
  to { opacity: 1; transform: translateY(0); }
}

/* ========== Search Filter ========== */
.search-filter {
  width: 100%; padding: 0.5rem 0.75rem; margin-bottom: 1rem;
  background: var(--input-bg); border: 1px solid var(--border);
  border-radius: 6px; color: var(--text-primary); font-size: 0.85rem;
}
.search-filter:focus { border-color: var(--accent); outline: none; }

/* ========== Screen Reader Only ========== */
.sr-only {
  position: absolute; width: 1px; height: 1px;
  padding: 0; margin: -1px; overflow: hidden;
  clip: rect(0, 0, 0, 0); white-space: nowrap; border: 0;
}

/* ========== Responsive ========== */
@media (max-width: 768px) {
  .header { flex-direction: column; gap: 0.5rem; padding: 0.5rem 1rem; }
@@ -1,46 +1,71 @@
import { useState, useCallback, useEffect, useRef } from 'react'
import { useState, useCallback, useEffect, useRef, lazy, Suspense } from 'react'
import { AvatarGrid } from './components/AvatarGrid'
import { ConsensusPanel } from './components/ConsensusPanel'
import { ChatMessage } from './components/ChatMessage'
import { VirtualMessages } from './components/VirtualMessages'
import { ToastProvider, useToast } from './components/Toast'
import { AdminPage } from './pages/AdminPage'
import { EthicsPage } from './pages/EthicsPage'
import { SettingsPage } from './pages/SettingsPage'
import { ErrorBoundary } from './components/ErrorBoundary'
import { MobileDrawer } from './components/MobileDrawer'
import { SkeletonGrid } from './components/Skeleton'
import { LoginPage } from './pages/LoginPage'
import { useTheme } from './hooks/useTheme'
import { useAuth } from './hooks/useAuth'
import { useWebSocket } from './hooks/useWebSocket'
import { useVoicePlayback } from './hooks/useVoicePlayback'
import { useKeyboard } from './hooks/useKeyboard'
import { useChatHistory } from './hooks/useChatHistory'
import type { FinalResponse, Page, ViewMode, WSEvent } from './types'
import './App.css'

const AdminPage = lazy(() => import('./pages/AdminPage').then((m) => ({ default: m.AdminPage })))
const EthicsPage = lazy(() => import('./pages/EthicsPage').then((m) => ({ default: m.EthicsPage })))
const SettingsPage = lazy(() => import('./pages/SettingsPage').then((m) => ({ default: m.SettingsPage })))

const HEAD_IDS = [
  'logic', 'research', 'systems', 'strategy', 'product',
  'security', 'safety', 'reliability', 'cost', 'data', 'devex', 'witness',
]

function PageSkeleton() {
  return (
    <div className="admin-page" role="status" aria-label="Loading page">
      <SkeletonGrid count={6} />
    </div>
  )
}

function App() {
  const { theme, toggle: toggleTheme } = useTheme()
  const { token, error: authError, setError: setAuthError, login, logout, authHeaders, isAuthenticated } = useAuth()
  const { toast } = useToast()
  const { token, error: authError, login, logout, authHeaders, isAuthenticated } = useAuth()
  const [page, setPage] = useState<Page>('chat')
  const [sessionId, setSessionId] = useState<string | null>(null)
  const [prompt, setPrompt] = useState('')
  const [messages, setMessages] = useState<{ role: 'user' | 'assistant'; content: string; data?: FinalResponse }[]>([])
  const { messages, addMessage, editMessage, deleteMessage, clearHistory, setMessages } = useChatHistory()
  const [loading, setLoading] = useState(false)
  const [activeHeads, setActiveHeads] = useState<string[]>([])
  const [viewMode, setViewMode] = useState<ViewMode>('normal')
  const [lastResponse, setLastResponse] = useState<FinalResponse | null>(null)
  const [networkError, setNetworkError] = useState<string | null>(null)
  const [useStreaming, setUseStreaming] = useState(false)
  const messagesEndRef = useRef<HTMLDivElement>(null)
  const [isMobile, setIsMobile] = useState(false)
  const inputRef = useRef<HTMLInputElement>(null)
  const fileInputRef = useRef<HTMLInputElement>(null)
  const { speakingHead, headSummaries, onHeadSpeak, clearSpeaking } = useVoicePlayback()
  const ws = useWebSocket(sessionId)

  useEffect(() => {
    messagesEndRef.current?.scrollIntoView({ behavior: 'smooth' })
  }, [messages])

  useEffect(() => {
    const check = () => setIsMobile(window.innerWidth <= 768)
    check()
    window.addEventListener('resize', check)
    return () => window.removeEventListener('resize', check)
  }, [])

  useEffect(() => {
    if ('serviceWorker' in navigator) {
      navigator.serviceWorker.register('/sw.js').catch(() => {})
    }
  }, [])

  // Handle WS events
  useEffect(() => {
    if (ws.events.length === 0) return
    const last = ws.events[ws.events.length - 1]
@@ -53,14 +78,10 @@ function App() {
        setActiveHeads(HEAD_IDS.slice(0, 6))
        break
      case 'head_complete':
        if (event.head_id && event.summary) {
          onHeadSpeak(event.head_id, event.summary, null)
        }
        if (event.head_id && event.summary) onHeadSpeak(event.head_id, event.summary, null)
        break
      case 'head_speak':
        if (event.head_id && event.summary) {
          onHeadSpeak(event.head_id, event.summary, event.audio_base64)
        }
        if (event.head_id && event.summary) onHeadSpeak(event.head_id, event.summary, event.audio_base64)
        break
      case 'witness_running':
        clearSpeaking()
@@ -74,13 +95,13 @@ function App() {
            confidence_score: event.confidence_score || 0,
          }
          setLastResponse(resp)
          setMessages((m) => [...m, { role: 'assistant', content: event.final_answer!, data: resp }])
          addMessage('assistant', event.final_answer!, resp)
        }
        setLoading(false)
        setActiveHeads([])
        break
      case 'error':
        setMessages((m) => [...m, { role: 'assistant', content: `Error: ${event.message}` }])
        addMessage('assistant', `Error: ${event.message}`)
        setLoading(false)
        setActiveHeads([])
        break
@@ -114,7 +135,7 @@ function App() {
    const sid = await ensureSession()
    if (!sid) return

    setMessages((m) => [...m, { role: 'user', content: prompt }])
    addMessage('user', prompt)
    const currentPrompt = prompt
    setPrompt('')
    setLoading(true)
@@ -141,30 +162,73 @@ function App() {
        const contribs = data.head_contributions || []
        contribs.forEach((c: { head_id: string; summary: string }) =>
          onHeadSpeak(c.head_id, c.summary, null))
        setMessages((m) => [...m, { role: 'assistant', content: data.final_answer, data }])
        addMessage('assistant', data.final_answer, data)
        setNetworkError(null)
      } catch (e) {
        const msg = (e as Error).message
        setNetworkError(msg)
        setMessages((m) => [...m, { role: 'assistant', content: `Error: ${msg}` }])
        addMessage('assistant', `Error: ${msg}`)
      } finally {
        setLoading(false)
        setActiveHeads([])
      }
    }
  }, [prompt, loading, ensureSession, useStreaming, ws, authHeaders, parseJson, clearSpeaking, onHeadSpeak])
  }, [prompt, loading, ensureSession, useStreaming, ws, authHeaders, parseJson, clearSpeaking, onHeadSpeak, addMessage])

  const handleRetry = () => {
    if (messages.length >= 2) {
      const lastUser = [...messages].reverse().find((m) => m.role === 'user')
      if (lastUser) {
        setPrompt(lastUser.content)
        setNetworkError(null)
      }
    const lastUser = [...messages].reverse().find((m) => m.role === 'user')
    if (lastUser) {
      setPrompt(lastUser.content)
      setNetworkError(null)
    }
  }

  // Login screen
  const handleEditMessage = useCallback((index: number) => {
    const msg = messages[index]
    if (msg?.role === 'user') {
      setPrompt(msg.content)
      toast('Message loaded for editing', 'info')
    }
  }, [messages, toast])

  const handleDeleteMessage = useCallback((index: number) => {
    deleteMessage(index)
    toast('Message deleted', 'info')
  }, [deleteMessage, toast])

  const handleFileUpload = useCallback(async (e: React.ChangeEvent<HTMLInputElement>) => {
    const file = e.target.files?.[0]
    if (!file) return
    if (file.size > 10 * 1024 * 1024) {
      toast('File too large (max 10MB)', 'error')
      return
    }
    const text = await file.text()
    setPrompt((p) => p + (p ? '\n' : '') + `[File: ${file.name}]\n${text.slice(0, 5000)}`)
    toast(`Attached: ${file.name}`, 'success')
    e.target.value = ''
  }, [toast])

  const syncPreferences = useCallback(async () => {
    try {
      const r = await fetch('/v1/admin/conversation-style', { headers: authHeaders() })
      if (r.ok) {
        toast('Preferences synced', 'success')
      }
    } catch { /* offline */ }
  }, [authHeaders, toast])

  useEffect(() => {
    if (isAuthenticated) syncPreferences()
  }, [isAuthenticated])

  useKeyboard({
    onSend: handleSubmit,
    onSearch: () => inputRef.current?.focus(),
    onDismiss: () => setNetworkError(null),
    onToggleTheme: toggleTheme,
  })

  if (!isAuthenticated && !token && token !== '') {
    return <LoginPage onLogin={login} error={authError} />
  }
@@ -220,43 +284,58 @@ function App() {
              speakingHead={speakingHead}
              headSummaries={headSummaries}
            />
            <div className="messages" role="log" aria-label="Conversation" aria-live="polite">
              {messages.length === 0 && (
              {messages.length === 0 ? (
                <div className="messages">
                  <div className="empty-state">
                    <h2>Welcome to FusionAGI Dvādaśa</h2>
                    <p>12 specialized heads analyze your query from every angle. Ask anything.</p>
                    <div className="suggestions">
                      {['Explain quantum entanglement', 'Design a microservice architecture', 'Analyze the ethics of AI autonomy'].map((s) => (
                        <button key={s} className="suggestion" onClick={() => { setPrompt(s); }}>
                        <button key={s} className="suggestion" onClick={() => setPrompt(s)}>
                          {s}
                        </button>
                      ))}
                    </div>
                  </div>
              )}
              {messages.map((msg, i) => (
                <ChatMessage key={i} message={msg} viewMode={viewMode} />
              ))}
              {loading && (
                <div className="loading-indicator" role="status" aria-live="assertive">
                  <div className="loading-dots" aria-hidden="true"><span /><span /><span /></div>
                  <span>Heads analyzing...</span>
                </div>
              )}
              <div ref={messagesEndRef} />
                </div>
              ) : (
                <VirtualMessages
                  messages={messages}
                  viewMode={viewMode}
                  loading={loading}
                  onEditMessage={handleEditMessage}
                  onDeleteMessage={handleDeleteMessage}
                />
              )}
            <div className="input-area">
              <div className="input-row">
                <input
                  ref={inputRef}
                  type="text"
                  value={prompt}
                  onChange={(e) => setPrompt(e.target.value)}
                  onKeyDown={(e) => e.key === 'Enter' && !e.shiftKey && handleSubmit()}
                  placeholder="Ask FusionAGI... (/head strategy, /show dissent)"
                  placeholder="Ask FusionAGI... (Ctrl+Enter to send, Ctrl+K to focus)"
                  autoComplete="off"
                  disabled={loading}
                  aria-label="Message input"
                />
                <input
                  ref={fileInputRef}
                  type="file"
                  className="sr-only"
                  onChange={handleFileUpload}
                  accept=".txt,.md,.json,.csv,.py,.js,.ts,.tsx"
                  aria-label="Attach file"
                />
                <button
                  className="icon-btn"
                  onClick={() => fileInputRef.current?.click()}
                  title="Attach file"
                  aria-label="Attach file"
                >
                  +
                </button>
                <button onClick={handleSubmit} disabled={loading || !prompt.trim()} className="send-btn" aria-label="Send message">
                  Send
                </button>
@@ -266,16 +345,30 @@ function App() {
                  <input type="checkbox" checked={useStreaming} onChange={(e) => setUseStreaming(e.target.checked)} />
                  <span>Stream</span>
                </label>
                {messages.length > 0 && (
                  <button className="clear-history-btn" onClick={() => { clearHistory(); toast('Chat history cleared', 'info') }}>
                    Clear
                  </button>
                )}
                {sessionId && <span className="session-id">Session: {sessionId.slice(0, 8)}...</span>}
              </div>
            </div>
          </div>
          <ConsensusPanel response={lastResponse} viewMode={viewMode} expanded={viewMode !== 'normal'} />
          {!isMobile && <ConsensusPanel response={lastResponse} viewMode={viewMode} expanded={viewMode !== 'normal'} />}
          {isMobile && lastResponse && (
            <MobileDrawer title="Consensus" visible={viewMode !== 'normal'}>
              <ConsensusPanel response={lastResponse} viewMode={viewMode} expanded={true} />
            </MobileDrawer>
          )}
        </div>
      )}
      {page === 'admin' && <AdminPage authHeaders={authHeaders} />}
      {page === 'ethics' && <EthicsPage authHeaders={authHeaders} />}
      {page === 'settings' && <SettingsPage theme={theme} toggleTheme={toggleTheme} authHeaders={authHeaders} />}
      <Suspense fallback={<PageSkeleton />}>
        <ErrorBoundary>
          {page === 'admin' && <AdminPage authHeaders={authHeaders} />}
          {page === 'ethics' && <EthicsPage authHeaders={authHeaders} />}
          {page === 'settings' && <SettingsPage theme={theme} toggleTheme={toggleTheme} authHeaders={authHeaders} />}
        </ErrorBoundary>
      </Suspense>
    </main>
  </div>
  )
86
frontend/src/components/AccessibilityChecker.tsx
Normal file
@@ -0,0 +1,86 @@
/**
 * Accessibility audit utility.
 *
 * Provides automated a11y checks that can be integrated into CI
 * or run manually during development. Uses DOM queries to verify
 * WCAG compliance of rendered components.
 */

export interface A11yViolation {
  rule: string
  element: string
  description: string
  severity: 'critical' | 'serious' | 'moderate' | 'minor'
}

export function auditAccessibility(root: HTMLElement = document.body): A11yViolation[] {
  const violations: A11yViolation[] = []

  // Check images without alt text
  root.querySelectorAll('img:not([alt])').forEach((el) => {
    violations.push({
      rule: 'img-alt',
      element: el.outerHTML.slice(0, 80),
      description: 'Image missing alt attribute',
      severity: 'critical',
    })
  })

  // Check buttons without accessible name
  root.querySelectorAll('button').forEach((el) => {
    const name = el.textContent?.trim() || el.getAttribute('aria-label') || el.getAttribute('title')
    if (!name) {
      violations.push({
        rule: 'button-name',
        element: el.outerHTML.slice(0, 80),
        description: 'Button has no accessible name',
        severity: 'serious',
      })
    }
  })

  // Check inputs without labels
  root.querySelectorAll('input:not([type="hidden"])').forEach((el) => {
    const id = el.getAttribute('id')
    const ariaLabel = el.getAttribute('aria-label') || el.getAttribute('aria-labelledby')
    const hasLabel = id ? root.querySelector(`label[for="${id}"]`) : false
    if (!ariaLabel && !hasLabel && !el.getAttribute('title')) {
      violations.push({
        rule: 'input-label',
        element: el.outerHTML.slice(0, 80),
        description: 'Input has no associated label',
        severity: 'serious',
      })
    }
  })

  // Check contrast (basic check for known problem patterns)
  root.querySelectorAll('[style*="color"]').forEach((el) => {
    const style = window.getComputedStyle(el as Element)
    const color = style.color
    const bg = style.backgroundColor
    if (color === bg && color !== 'rgba(0, 0, 0, 0)') {
      violations.push({
        rule: 'color-contrast',
        element: (el as Element).outerHTML.slice(0, 80),
        description: 'Text and background colors are identical',
        severity: 'critical',
      })
    }
  })

  // Check for tabindex > 0
  root.querySelectorAll('[tabindex]').forEach((el) => {
    const idx = parseInt(el.getAttribute('tabindex') || '0', 10)
    if (idx > 0) {
      violations.push({
        rule: 'tabindex',
        element: el.outerHTML.slice(0, 80),
        description: 'Positive tabindex disrupts natural tab order',
        severity: 'moderate',
      })
    }
  })

  return violations
}
21
frontend/src/components/Avatar.stories.tsx
Normal file
@@ -0,0 +1,21 @@
import type { Meta, StoryObj } from '@storybook/react'
import { Avatar } from './Avatar'

const meta: Meta<typeof Avatar> = {
  title: 'Components/Avatar',
  component: Avatar,
  argTypes: {
    headId: {
      control: 'select',
      options: ['logic', 'research', 'systems', 'strategy', 'product', 'security', 'safety', 'reliability', 'cost', 'data', 'devex', 'witness'],
    },
  },
}

export default meta
type Story = StoryObj<typeof Avatar>

export const Idle: Story = { args: { headId: 'logic' } }
export const Active: Story = { args: { headId: 'research', isActive: true } }
export const Speaking: Story = { args: { headId: 'strategy', isSpeaking: true } }
export const WithSummary: Story = { args: { headId: 'security', isActive: true, summary: 'Analyzing threat vectors' } }
36 frontend/src/components/Avatar.test.tsx Normal file
@@ -0,0 +1,36 @@
import { describe, it, expect } from 'vitest'
import { render, screen } from '@testing-library/react'
import { Avatar } from './Avatar'

describe('Avatar', () => {
  it('renders head name', () => {
    render(<Avatar headId="logic" />)
    expect(screen.getByText('Logic')).toBeTruthy()
  })

  it('shows 2-letter placeholder', () => {
    const { container } = render(<Avatar headId="research" />)
    expect(container.querySelector('.avatar-placeholder')?.textContent).toBe('re')
  })

  it('applies active class when active', () => {
    const { container } = render(<Avatar headId="logic" isActive={true} />)
    expect(container.querySelector('.avatar.active')).toBeTruthy()
  })

  it('applies speaking class when speaking', () => {
    const { container } = render(<Avatar headId="logic" isSpeaking={true} />)
    expect(container.querySelector('.avatar.speaking')).toBeTruthy()
  })

  it('has data-head attribute', () => {
    const { container } = render(<Avatar headId="strategy" />)
    expect(container.querySelector('[data-head="strategy"]')).toBeTruthy()
  })

  it('has aria-label with status', () => {
    render(<Avatar headId="logic" isActive={true} />)
    const el = screen.getByRole('status')
    expect(el.getAttribute('aria-label')).toContain('active')
  })
})
38 frontend/src/components/ChatMessage.test.tsx Normal file
@@ -0,0 +1,38 @@
import { describe, it, expect } from 'vitest'
import { render, screen } from '@testing-library/react'
import { ChatMessage } from './ChatMessage'

describe('ChatMessage', () => {
  it('renders user message', () => {
    render(<ChatMessage message={{ role: 'user', content: 'Hello' }} viewMode="normal" />)
    expect(screen.getByText('Hello')).toBeTruthy()
  })

  it('renders assistant message with markdown', () => {
    render(<ChatMessage message={{ role: 'assistant', content: '**Bold response**' }} viewMode="normal" />)
    expect(screen.getByText('Bold response')).toBeTruthy()
  })

  it('shows head contributions in explain mode', () => {
    const data = {
      final_answer: 'Answer',
      transparency_report: { head_contributions: [], agreement_map: { agreed_claims: [], disputed_claims: [], confidence_score: 0.9 }, safety_report: '', confidence_score: 0.9 },
      head_contributions: [{ head_id: 'logic', summary: 'Logical analysis' }],
      confidence_score: 0.9,
    }
    render(<ChatMessage message={{ role: 'assistant', content: 'Answer', data }} viewMode="explain" />)
    expect(screen.getByText('logic')).toBeTruthy()
    expect(screen.getByText('Logical analysis')).toBeTruthy()
  })

  it('hides head contributions in normal mode', () => {
    const data = {
      final_answer: 'Answer',
      transparency_report: { head_contributions: [], agreement_map: { agreed_claims: [], disputed_claims: [], confidence_score: 0.9 }, safety_report: '', confidence_score: 0.9 },
      head_contributions: [{ head_id: 'logic', summary: 'Logical analysis' }],
      confidence_score: 0.9,
    }
    render(<ChatMessage message={{ role: 'assistant', content: 'Answer', data }} viewMode="normal" />)
    expect(screen.queryByText('logic')).toBeNull()
  })
})
frontend/src/components/ChatMessage.tsx
@@ -1,9 +1,12 @@
+import { useState } from 'react'
 import type { FinalResponse } from '../types'
 import { Markdown } from './Markdown'
 
 interface ChatMessageProps {
   message: { role: 'user' | 'assistant'; content: string; data?: FinalResponse }
   viewMode: string
+  onEdit?: () => void
+  onDelete?: () => void
 }
 
 function extractSynthesis(content: string): string {
@@ -18,13 +21,26 @@ function extractSynthesis(content: string): string {
   return filtered.join('\n').trim()
 }
 
-export function ChatMessage({ message, viewMode }: ChatMessageProps) {
+export function ChatMessage({ message, viewMode, onEdit, onDelete }: ChatMessageProps) {
   const isUser = message.role === 'user'
+  const [showActions, setShowActions] = useState(false)
 
   if (isUser) {
     return (
-      <div className="message user" role="log" aria-label="Your message">
+      <div
+        className="message user"
+        role="log"
+        aria-label="Your message"
+        onMouseEnter={() => setShowActions(true)}
+        onMouseLeave={() => setShowActions(false)}
+      >
         <div className="message-content">{message.content}</div>
+        {showActions && (onEdit || onDelete) && (
+          <div className="message-actions">
+            {onEdit && <button className="msg-action-btn" onClick={onEdit} aria-label="Edit message">Edit</button>}
+            {onDelete && <button className="msg-action-btn" onClick={onDelete} aria-label="Delete message">Del</button>}
+          </div>
+        )}
       </div>
     )
   }
@@ -33,7 +49,13 @@ export function ChatMessage({ message, viewMode }: ChatMessageProps) {
   const synthesis = extractSynthesis(message.content)
 
   return (
-    <div className="message assistant" role="log" aria-label="FusionAGI response">
+    <div
+      className="message assistant"
+      role="log"
+      aria-label="FusionAGI response"
+      onMouseEnter={() => setShowActions(true)}
+      onMouseLeave={() => setShowActions(false)}
+    >
       <div className="response-structured">
         <Markdown content={synthesis} />
         {hasHeadData && (viewMode === 'explain' || viewMode === 'developer') && (
@@ -57,6 +79,11 @@ export function ChatMessage({ message, viewMode }: ChatMessageProps) {
         </div>
       )}
       </div>
+      {showActions && onDelete && (
+        <div className="message-actions">
+          <button className="msg-action-btn" onClick={onDelete} aria-label="Delete message">Del</button>
+        </div>
+      )}
     </div>
   )
 }
41 frontend/src/components/ErrorBoundary.test.tsx Normal file
@@ -0,0 +1,41 @@
import { describe, it, expect, vi } from 'vitest'
import { render, screen } from '@testing-library/react'
import { ErrorBoundary } from './ErrorBoundary'

function ThrowingComponent() {
  throw new Error('Test error')
}

describe('ErrorBoundary', () => {
  it('catches errors and shows fallback', () => {
    const spy = vi.spyOn(console, 'error').mockImplementation(() => {})
    render(
      <ErrorBoundary>
        <ThrowingComponent />
      </ErrorBoundary>
    )
    expect(screen.getByText('Something went wrong')).toBeTruthy()
    expect(screen.getByText('Test error')).toBeTruthy()
    spy.mockRestore()
  })

  it('renders children when no error', () => {
    render(
      <ErrorBoundary>
        <div>Working fine</div>
      </ErrorBoundary>
    )
    expect(screen.getByText('Working fine')).toBeTruthy()
  })

  it('shows custom fallback', () => {
    const spy = vi.spyOn(console, 'error').mockImplementation(() => {})
    render(
      <ErrorBoundary fallback={<div>Custom fallback</div>}>
        <ThrowingComponent />
      </ErrorBoundary>
    )
    expect(screen.getByText('Custom fallback')).toBeTruthy()
    spy.mockRestore()
  })
})
48 frontend/src/components/ErrorBoundary.tsx Normal file
@@ -0,0 +1,48 @@
import { Component } from 'react'
import type { ReactNode, ErrorInfo } from 'react'

interface Props {
  children: ReactNode
  fallback?: ReactNode
  onError?: (error: Error, info: ErrorInfo) => void
}

interface State {
  hasError: boolean
  error: Error | null
}

export class ErrorBoundary extends Component<Props, State> {
  constructor(props: Props) {
    super(props)
    this.state = { hasError: false, error: null }
  }

  static getDerivedStateFromError(error: Error): State {
    return { hasError: true, error }
  }

  componentDidCatch(error: Error, info: ErrorInfo) {
    console.error('ErrorBoundary caught:', error, info)
    this.props.onError?.(error, info)
  }

  render() {
    if (this.state.hasError) {
      if (this.props.fallback) return this.props.fallback
      return (
        <div className="error-boundary-fallback" role="alert">
          <h3>Something went wrong</h3>
          <p className="muted">{this.state.error?.message || 'An unexpected error occurred'}</p>
          <button
            className="theme-toggle"
            onClick={() => this.setState({ hasError: false, error: null })}
          >
            Try again
          </button>
        </div>
      )
    }
    return this.props.children
  }
}
44 frontend/src/components/Markdown.test.tsx Normal file
@@ -0,0 +1,44 @@
import { describe, it, expect } from 'vitest'
import { render, screen } from '@testing-library/react'
import { Markdown } from './Markdown'

describe('Markdown', () => {
  it('renders paragraphs', () => {
    render(<Markdown content="Hello world" />)
    expect(screen.getByText('Hello world')).toBeTruthy()
  })

  it('renders bold text', () => {
    const { container } = render(<Markdown content="**bold text**" />)
    expect(container.querySelector('strong')?.textContent).toBe('bold text')
  })

  it('renders inline code', () => {
    const { container } = render(<Markdown content="Use `console.log`" />)
    expect(container.querySelector('code')?.textContent).toBe('console.log')
  })

  it('renders unordered lists', () => {
    const { container } = render(<Markdown content={'- item one\n- item two'} />)
    const items = container.querySelectorAll('li')
    expect(items.length).toBe(2)
  })

  it('renders headings', () => {
    const { container } = render(<Markdown content="# Title" />)
    expect(container.querySelector('h1')?.textContent).toBe('Title')
  })

  it('renders code blocks with copy button', () => {
    const { container } = render(<Markdown content="```js\nconsole.log('hi')\n```" />)
    expect(container.querySelector('.copy-code-btn')).toBeTruthy()
    expect(container.querySelector('pre')).toBeTruthy()
  })

  it('renders links', () => {
    const { container } = render(<Markdown content="[Click](https://example.com)" />)
    const a = container.querySelector('a')
    expect(a?.getAttribute('href')).toBe('https://example.com')
    expect(a?.getAttribute('target')).toBe('_blank')
  })
})
frontend/src/components/Markdown.tsx
@@ -1,3 +1,5 @@
+import { useCallback, useRef, useEffect } from 'react'
+
 function escapeHtml(text: string): string {
   return text.replace(/&/g, '&amp;').replace(/</g, '&lt;').replace(/>/g, '&gt;')
 }
@@ -16,17 +18,21 @@ function parseMarkdown(md: string): string {
   const html: string[] = []
   let inCode = false
   let codeBlock: string[] = []
+  let codeLang = ''
   let inList = false
   let listType: 'ul' | 'ol' = 'ul'
 
   for (const line of lines) {
     if (line.startsWith('```')) {
       if (inCode) {
-        html.push(`<pre><code>${escapeHtml(codeBlock.join('\n'))}</code></pre>`)
+        const escaped = escapeHtml(codeBlock.join('\n'))
+        html.push(`<div class="code-block-wrapper"><button class="copy-code-btn" data-code="${encodeURIComponent(codeBlock.join('\n'))}">Copy</button><pre><code class="lang-${codeLang}">${escaped}</code></pre></div>`)
         codeBlock = []
+        codeLang = ''
         inCode = false
       } else {
         if (inList) { html.push(`</${listType}>`); inList = false }
+        codeLang = line.slice(3).trim()
         inCode = true
       }
       continue
@@ -68,14 +74,40 @@ function parseMarkdown(md: string): string {
       html.push(`<p>${renderInline(trimmed)}</p>`)
     }
   }
-  if (inCode) html.push(`<pre><code>${escapeHtml(codeBlock.join('\n'))}</code></pre>`)
+  if (inCode) {
+    const escaped = escapeHtml(codeBlock.join('\n'))
+    html.push(`<div class="code-block-wrapper"><button class="copy-code-btn" data-code="${encodeURIComponent(codeBlock.join('\n'))}">Copy</button><pre><code>${escaped}</code></pre></div>`)
+  }
   if (inList) html.push(`</${listType}>`)
   return html.join('')
 }
 
 export function Markdown({ content }: { content: string }) {
+  const ref = useRef<HTMLDivElement>(null)
+
+  const handleClick = useCallback((e: MouseEvent) => {
+    const btn = (e.target as HTMLElement).closest('.copy-code-btn') as HTMLButtonElement | null
+    if (!btn) return
+    const code = decodeURIComponent(btn.dataset.code || '')
+    navigator.clipboard.writeText(code).then(() => {
+      btn.textContent = 'Copied!'
+      setTimeout(() => { btn.textContent = 'Copy' }, 2000)
+    }).catch(() => {
+      btn.textContent = 'Failed'
+      setTimeout(() => { btn.textContent = 'Copy' }, 2000)
+    })
+  }, [])
+
+  useEffect(() => {
+    const el = ref.current
+    if (!el) return
+    el.addEventListener('click', handleClick as EventListener)
+    return () => el.removeEventListener('click', handleClick as EventListener)
+  }, [handleClick])
+
   return (
     <div
+      ref={ref}
       className="response-synthesis"
       dangerouslySetInnerHTML={{ __html: parseMarkdown(content) }}
     />
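The copy button stores each snippet in a `data-code` attribute via `encodeURIComponent`, and the click handler decodes it back with `decodeURIComponent`. A small sketch of why that round-trip is safe for attribute storage — quotes, backticks, and newlines never reach the HTML attribute verbatim. The helper names here are illustrative wrappers, not part of the commit:

```typescript
// Encode a code snippet for storage in an HTML data attribute.
// encodeURIComponent escapes quotes, newlines, and backticks, so the
// attribute value cannot break out of its surrounding markup.
function encodeForAttr(code: string): string {
  return encodeURIComponent(code)
}

// Recover the original snippet from the attribute value.
function decodeFromAttr(attr: string): string {
  return decodeURIComponent(attr)
}
```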
44 frontend/src/components/MobileDrawer.tsx Normal file
@@ -0,0 +1,44 @@
import { useState } from 'react'
import type { ReactNode } from 'react'

interface MobileDrawerProps {
  children: ReactNode
  title: string
  visible: boolean
}

export function MobileDrawer({ children, title, visible }: MobileDrawerProps) {
  const [open, setOpen] = useState(false)

  if (!visible) return null

  return (
    <>
      <button
        className="drawer-trigger"
        onClick={() => setOpen(true)}
        aria-label={`Open ${title}`}
      >
        {title}
      </button>
      {open && (
        <div className="drawer-overlay" onClick={() => setOpen(false)}>
          <div
            className="drawer-panel"
            onClick={(e) => e.stopPropagation()}
            role="dialog"
            aria-label={title}
          >
            <div className="drawer-header">
              <h3>{title}</h3>
              <button className="icon-btn" onClick={() => setOpen(false)} aria-label="Close">X</button>
            </div>
            <div className="drawer-body">
              {children}
            </div>
          </div>
        </div>
      )}
    </>
  )
}
29 frontend/src/components/SearchFilter.tsx Normal file
@@ -0,0 +1,29 @@
import { useState, useEffect, useRef } from 'react'

interface SearchFilterProps {
  placeholder?: string
  onFilter: (query: string) => void
  debounceMs?: number
}

export function SearchFilter({ placeholder = 'Search...', onFilter, debounceMs = 300 }: SearchFilterProps) {
  const [value, setValue] = useState('')
  const timer = useRef<ReturnType<typeof setTimeout> | null>(null)

  useEffect(() => {
    if (timer.current) clearTimeout(timer.current)
    timer.current = setTimeout(() => onFilter(value), debounceMs)
    return () => { if (timer.current) clearTimeout(timer.current) }
  }, [value, debounceMs, onFilter])

  return (
    <input
      type="search"
      className="search-filter"
      value={value}
      onChange={(e) => setValue(e.target.value)}
      placeholder={placeholder}
      aria-label={placeholder}
    />
  )
}
20 frontend/src/components/Skeleton.test.tsx Normal file
@@ -0,0 +1,20 @@
import { describe, it, expect } from 'vitest'
import { render } from '@testing-library/react'
import { Skeleton, SkeletonCard, SkeletonGrid } from './Skeleton'

describe('Skeleton', () => {
  it('renders specified count of skeleton lines', () => {
    const { container } = render(<Skeleton count={3} />)
    expect(container.querySelectorAll('.skeleton').length).toBe(3)
  })

  it('renders skeleton card', () => {
    const { container } = render(<SkeletonCard />)
    expect(container.querySelector('.skeleton-card')).toBeTruthy()
  })

  it('renders skeleton grid with count', () => {
    const { container } = render(<SkeletonGrid count={4} />)
    expect(container.querySelectorAll('.skeleton-card').length).toBe(4)
  })
})
45 frontend/src/components/Skeleton.tsx Normal file
@@ -0,0 +1,45 @@
interface SkeletonProps {
  width?: string
  height?: string
  count?: number
  className?: string
}

function SkeletonLine({ width, height, className }: SkeletonProps) {
  return (
    <div
      className={`skeleton ${className || ''}`}
      style={{ width: width || '100%', height: height || '1rem' }}
      aria-hidden="true"
    />
  )
}

export function Skeleton({ width, height, count = 1, className }: SkeletonProps) {
  return (
    <>
      {Array.from({ length: count }, (_, i) => (
        <SkeletonLine key={i} width={width} height={height} className={className} />
      ))}
    </>
  )
}

export function SkeletonCard() {
  return (
    <div className="skeleton-card" aria-hidden="true">
      <Skeleton width="40%" height="0.75rem" />
      <Skeleton width="70%" height="1.2rem" />
    </div>
  )
}

export function SkeletonGrid({ count = 6 }: { count?: number }) {
  return (
    <div className="status-grid" role="status" aria-label="Loading">
      {Array.from({ length: count }, (_, i) => (
        <SkeletonCard key={i} />
      ))}
    </div>
  )
}
24 frontend/src/components/Toast.test.tsx Normal file
@@ -0,0 +1,24 @@
import { describe, it, expect } from 'vitest'
import { renderHook, act } from '@testing-library/react'
import { render, screen } from '@testing-library/react'
import { ToastProvider, useToast } from './Toast'
import type { ReactNode } from 'react'

const wrapper = ({ children }: { children: ReactNode }) => <ToastProvider>{children}</ToastProvider>

describe('Toast', () => {
  it('shows toast message', () => {
    function TestComponent() {
      const { toast } = useToast()
      return <button onClick={() => toast('Test message', 'success')}>Show</button>
    }
    render(<ToastProvider><TestComponent /></ToastProvider>)
    act(() => { screen.getByText('Show').click() })
    expect(screen.getByText('Test message')).toBeTruthy()
  })

  it('provides toast function via hook', () => {
    const { result } = renderHook(() => useToast(), { wrapper })
    expect(typeof result.current.toast).toBe('function')
  })
})
84 frontend/src/components/VirtualMessages.tsx Normal file
@@ -0,0 +1,84 @@
import { useRef, useEffect, useCallback, useState } from 'react'
import type { FinalResponse } from '../types'
import { ChatMessage } from './ChatMessage'

interface Message {
  role: 'user' | 'assistant'
  content: string
  data?: FinalResponse
}

interface VirtualMessagesProps {
  messages: Message[]
  viewMode: string
  loading: boolean
  onEditMessage?: (index: number) => void
  onDeleteMessage?: (index: number) => void
}

const BUFFER = 10
const BATCH_SIZE = 30

export function VirtualMessages({ messages, viewMode, loading, onEditMessage, onDeleteMessage }: VirtualMessagesProps) {
  const containerRef = useRef<HTMLDivElement>(null)
  const endRef = useRef<HTMLDivElement>(null)
  const [visibleStart, setVisibleStart] = useState(0)

  useEffect(() => {
    const start = Math.max(0, messages.length - BATCH_SIZE)
    setVisibleStart(start)
  }, [messages.length])

  useEffect(() => {
    endRef.current?.scrollIntoView({ behavior: 'smooth' })
  }, [messages.length])

  const handleScroll = useCallback(() => {
    const el = containerRef.current
    if (!el) return
    if (el.scrollTop < 100 && visibleStart > 0) {
      setVisibleStart((s) => Math.max(0, s - BUFFER))
    }
  }, [visibleStart])

  const visibleMessages = messages.slice(visibleStart)

  return (
    <div
      className="messages"
      ref={containerRef}
      onScroll={handleScroll}
      role="log"
      aria-label="Conversation"
      aria-live="polite"
    >
      {visibleStart > 0 && (
        <button
          className="load-more-btn"
          onClick={() => setVisibleStart((s) => Math.max(0, s - BATCH_SIZE))}
        >
          Load {Math.min(BATCH_SIZE, visibleStart)} earlier messages
        </button>
      )}
      {visibleMessages.map((msg, i) => {
        const realIndex = visibleStart + i
        return (
          <ChatMessage
            key={realIndex}
            message={msg}
            viewMode={viewMode}
            onEdit={msg.role === 'user' && onEditMessage ? () => onEditMessage(realIndex) : undefined}
            onDelete={onDeleteMessage ? () => onDeleteMessage(realIndex) : undefined}
          />
        )
      })}
      {loading && (
        <div className="loading-indicator" role="status" aria-live="assertive">
          <div className="loading-dots" aria-hidden="true"><span /><span /><span /></div>
          <span>Heads analyzing...</span>
        </div>
      )}
      <div ref={endRef} />
    </div>
  )
}
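VirtualMessages windows the list: the newest `BATCH_SIZE` messages render first, and scrolling near the top (or clicking the load-more button) walks `visibleStart` back toward zero. The index arithmetic as a pure sketch — the function names are illustrative, only the constants come from the component:

```typescript
// Constants mirror VirtualMessages; the helpers are illustrative only.
const BATCH_SIZE = 30
const BUFFER = 10

// Initial window: show only the last BATCH_SIZE messages.
function initialStart(total: number): number {
  return Math.max(0, total - BATCH_SIZE)
}

// Scrolling near the top reveals BUFFER earlier messages, clamped at 0.
function loadEarlier(start: number): number {
  return Math.max(0, start - BUFFER)
}
```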
56 frontend/src/e2e.test.tsx Normal file
@@ -0,0 +1,56 @@
/**
 * End-to-end smoke tests for FusionAGI frontend.
 *
 * These tests verify that major UI components render correctly
 * and basic navigation/interaction flows work.
 */

import { describe, it, expect, vi, beforeEach } from 'vitest'
import { render, fireEvent } from '@testing-library/react'
import App from './App'

// Mock fetch for API calls
globalThis.fetch = vi.fn(() =>
  Promise.resolve({
    ok: true,
    json: () => Promise.resolve({ status: 'ok' }),
    text: () => Promise.resolve(''),
  } as Response)
)

beforeEach(() => {
  // Set auth token so app renders main interface instead of login
  localStorage.setItem('fusionagi-token', 'test-token')
})

describe('E2E Smoke Tests', () => {
  it('renders the main chat interface when authenticated', () => {
    const { container } = render(<App />)
    expect(container.querySelector('.app')).toBeTruthy()
  })

  it('renders the logo', () => {
    const { container } = render(<App />)
    expect(container.querySelector('.logo')).toBeTruthy()
    expect(container.querySelector('.logo')?.textContent).toBe('FusionAGI')
  })

  it('has a prompt input', () => {
    const { container } = render(<App />)
    const input = container.querySelector('input[aria-label="Message input"]')
    expect(input).toBeTruthy()
  })

  it('renders navigation tabs', () => {
    const { container } = render(<App />)
    const nav = container.querySelector('[role="tablist"]')
    expect(nav).toBeTruthy()
  })

  it('shows login page when not authenticated', () => {
    localStorage.removeItem('fusionagi-token')
    const { container } = render(<App />)
    const loginPage = container.querySelector('.login-page, form, input')
    expect(loginPage).toBeTruthy()
  })
})
47 frontend/src/hooks/useChatHistory.test.ts Normal file
@@ -0,0 +1,47 @@
import { describe, it, expect, beforeEach } from 'vitest'
import { renderHook, act } from '@testing-library/react'
import { useChatHistory } from './useChatHistory'

describe('useChatHistory', () => {
  beforeEach(() => {
    localStorage.clear()
  })

  it('starts empty', () => {
    const { result } = renderHook(() => useChatHistory())
    expect(result.current.messages).toEqual([])
  })

  it('adds messages', () => {
    const { result } = renderHook(() => useChatHistory())
    act(() => { result.current.addMessage('user', 'Hello') })
    expect(result.current.messages.length).toBe(1)
    expect(result.current.messages[0].role).toBe('user')
    expect(result.current.messages[0].content).toBe('Hello')
  })

  it('deletes messages', () => {
    const { result } = renderHook(() => useChatHistory())
    act(() => { result.current.addMessage('user', 'First') })
    act(() => { result.current.addMessage('assistant', 'Second') })
    expect(result.current.messages.length).toBe(2)
    act(() => { result.current.deleteMessage(0) })
    expect(result.current.messages.length).toBe(1)
    expect(result.current.messages[0].content).toBe('Second')
  })

  it('clears history', () => {
    const { result } = renderHook(() => useChatHistory())
    act(() => { result.current.addMessage('user', 'Test') })
    act(() => { result.current.clearHistory() })
    expect(result.current.messages).toEqual([])
  })

  it('persists to localStorage', () => {
    const { result } = renderHook(() => useChatHistory())
    act(() => { result.current.addMessage('user', 'Persisted') })
    const stored = localStorage.getItem('fusionagi-chat-history')
    expect(stored).toBeTruthy()
    expect(JSON.parse(stored!)[0].content).toBe('Persisted')
  })
})
69 frontend/src/hooks/useChatHistory.ts Normal file
@@ -0,0 +1,69 @@
import { useState, useCallback, useEffect } from 'react'
import type { FinalResponse } from '../types'

interface ChatMessage {
  role: 'user' | 'assistant'
  content: string
  data?: FinalResponse
  id: string
  timestamp: number
}

const STORAGE_KEY = 'fusionagi-chat-history'
const MAX_MESSAGES = 500

function generateId(): string {
  return `${Date.now()}-${Math.random().toString(36).slice(2, 9)}`
}

function loadHistory(): ChatMessage[] {
  try {
    const raw = localStorage.getItem(STORAGE_KEY)
    if (!raw) return []
    return JSON.parse(raw)
  } catch {
    return []
  }
}

function saveHistory(messages: ChatMessage[]) {
  try {
    const trimmed = messages.slice(-MAX_MESSAGES)
    localStorage.setItem(STORAGE_KEY, JSON.stringify(trimmed))
  } catch { /* storage full */ }
}

export function useChatHistory() {
  const [messages, setMessages] = useState<ChatMessage[]>(() => loadHistory())

  useEffect(() => {
    saveHistory(messages)
  }, [messages])

  const addMessage = useCallback((role: 'user' | 'assistant', content: string, data?: FinalResponse) => {
    const msg: ChatMessage = { role, content, data, id: generateId(), timestamp: Date.now() }
    setMessages((prev) => [...prev, msg])
    return msg
  }, [])

  const editMessage = useCallback((index: number, newContent: string) => {
    setMessages((prev) => {
      const updated = [...prev]
      if (updated[index] && updated[index].role === 'user') {
        updated[index] = { ...updated[index], content: newContent }
      }
      return updated
    })
  }, [])

  const deleteMessage = useCallback((index: number) => {
    setMessages((prev) => prev.filter((_, i) => i !== index))
  }, [])

  const clearHistory = useCallback(() => {
    setMessages([])
    localStorage.removeItem(STORAGE_KEY)
  }, [])

  return { messages, addMessage, editMessage, deleteMessage, clearHistory, setMessages }
}
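`saveHistory` enforces the 500-message cap with `slice(-MAX_MESSAGES)`, which keeps the newest entries and leaves shorter lists untouched. The trimming behavior as an isolated sketch (`trimHistory` is a hypothetical helper name):

```typescript
const MAX_MESSAGES = 500

// Keep only the most recent `max` entries; shorter arrays pass through
// unchanged (illustrative helper mirroring saveHistory's slice).
function trimHistory<T>(messages: T[], max = MAX_MESSAGES): T[] {
  return messages.slice(-max)
}
```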
44 frontend/src/hooks/useKeyboard.ts Normal file
@@ -0,0 +1,44 @@
import { useEffect, useCallback } from 'react'

interface KeyboardShortcuts {
  onSend?: () => void
  onSearch?: () => void
  onDismiss?: () => void
  onToggleTheme?: () => void
}

export function useKeyboard({ onSend, onSearch, onDismiss, onToggleTheme }: KeyboardShortcuts) {
  const handler = useCallback((e: KeyboardEvent) => {
    const meta = e.metaKey || e.ctrlKey
    const target = e.target as HTMLElement
    const isInput = target.tagName === 'INPUT' || target.tagName === 'TEXTAREA' || target.isContentEditable

    if (e.key === 'Escape') {
      onDismiss?.()
      return
    }

    if (meta && e.key === 'Enter' && onSend) {
      e.preventDefault()
      onSend()
      return
    }

    if (meta && e.key === 'k' && onSearch) {
      e.preventDefault()
      onSearch()
      return
    }

    if (meta && e.key === 'j' && onToggleTheme && !isInput) {
      e.preventDefault()
      onToggleTheme()
      return
    }
  }, [onSend, onSearch, onDismiss, onToggleTheme])

  useEffect(() => {
    window.addEventListener('keydown', handler)
    return () => window.removeEventListener('keydown', handler)
  }, [handler])
}
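The handler treats Ctrl and Cmd interchangeably via `e.metaKey || e.ctrlKey`, so Ctrl+K and Cmd+K trigger the same shortcut. A hypothetical pure matcher isolating that check (not part of the hook):

```typescript
// Minimal shape of the event fields the shortcut check reads.
interface KeyEventLike {
  key: string
  metaKey: boolean
  ctrlKey: boolean
}

// True when the key is pressed with either Ctrl or Cmd held
// (illustrative helper mirroring the hook's `meta` check).
function matchesShortcut(e: KeyEventLike, key: string): boolean {
  return (e.metaKey || e.ctrlKey) && e.key === key
}
```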
@@ -3,21 +3,43 @@ import type { WSEvent } from '../types'
type WSStatus = 'disconnected' | 'connecting' | 'connected' | 'error'

const MAX_RETRIES = 10
const BASE_DELAY = 1000

export function useWebSocket(sessionId: string | null) {
  const [status, setStatus] = useState<WSStatus>('disconnected')
  const [events, setEvents] = useState<WSEvent[]>([])
  const wsRef = useRef<WebSocket | null>(null)
  const retryCount = useRef(0)
  const retryTimer = useRef<ReturnType<typeof setTimeout> | null>(null)
  const shouldReconnect = useRef(true)

  const connect = useCallback((sid: string) => {
    if (wsRef.current?.readyState === WebSocket.OPEN) return
    if (wsRef.current) wsRef.current.close()
    shouldReconnect.current = true
    setStatus('connecting')

    const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:'
    const ws = new WebSocket(`${protocol}//${window.location.host}/v1/sessions/${sid}/stream`)
    wsRef.current = ws

    ws.onopen = () => setStatus('connected')
    ws.onclose = () => setStatus('disconnected')
    ws.onopen = () => {
      setStatus('connected')
      retryCount.current = 0
    }

    ws.onclose = () => {
      setStatus('disconnected')
      if (shouldReconnect.current && retryCount.current < MAX_RETRIES) {
        const delay = BASE_DELAY * Math.pow(2, retryCount.current) + Math.random() * 500
        retryCount.current++
        retryTimer.current = setTimeout(() => connect(sid), delay)
      }
    }

    ws.onerror = () => setStatus('error')

    ws.onmessage = (e) => {
      try {
        const event: WSEvent = JSON.parse(e.data)
@@ -33,14 +55,23 @@ export function useWebSocket(sessionId: string | null) {
  }, [])

  const disconnect = useCallback(() => {
    shouldReconnect.current = false
    if (retryTimer.current) clearTimeout(retryTimer.current)
    wsRef.current?.close()
    wsRef.current = null
    setStatus('disconnected')
    retryCount.current = 0
  }, [])

  const clearEvents = useCallback(() => setEvents([]), [])

  useEffect(() => () => { wsRef.current?.close() }, [])
  useEffect(() => {
    return () => {
      shouldReconnect.current = false
      if (retryTimer.current) clearTimeout(retryTimer.current)
      wsRef.current?.close()
    }
  }, [])

  return { status, events, connect, send, disconnect, clearEvents }
}
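The reconnect handler above retries with capped exponential backoff plus random jitter. As an editorial illustration (not part of the commit), the same delay schedule can be sketched in Python; `reconnect_delays` is a hypothetical name, and the constants mirror the hook's `MAX_RETRIES` and `BASE_DELAY` (converted to seconds, jitter capped at 0.5 s):

```python
import random

MAX_RETRIES = 10
BASE_DELAY = 1.0  # seconds; the hook uses 1000 ms


def reconnect_delays(retries: int = MAX_RETRIES, base: float = BASE_DELAY) -> list[float]:
    """Delay before each retry attempt: base * 2^n plus up to 0.5 s of jitter."""
    return [base * 2 ** n + random.uniform(0, 0.5) for n in range(retries)]


delays = reconnect_delays()
# Delays grow geometrically (~1 s, ~2 s, ~4 s, ... ~512 s) before giving up.
```

Because each step doubles while the jitter is bounded by 0.5 s, the schedule is strictly increasing, so reconnect storms are spread out rather than synchronized.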
27
fusionagi/adapters/stt.py
Normal file
@@ -0,0 +1,27 @@
"""STT adapter factory for VoiceManager integration."""

from __future__ import annotations

import os

from fusionagi.adapters.stt_adapter import STTAdapter, StubSTTAdapter


def get_stt_adapter(provider: str = "stub") -> STTAdapter:
    """Get an STT adapter for the given provider name.

    Args:
        provider: Provider identifier (stub, whisper, azure).

    Returns:
        Configured STTAdapter instance.
    """
    if provider == "whisper":
        try:
            from fusionagi.adapters.stt_adapter import WhisperSTTAdapter

            api_key = os.environ.get("OPENAI_API_KEY", "")
            if api_key:
                return WhisperSTTAdapter(api_key=api_key)
        except ImportError:
            pass
    return StubSTTAdapter()
24
fusionagi/adapters/tts.py
Normal file
@@ -0,0 +1,24 @@
"""TTS adapter factory for VoiceManager integration."""

from __future__ import annotations

import os

from fusionagi.adapters.tts_adapter import ElevenLabsTTSAdapter, StubTTSAdapter, TTSAdapter


def get_tts_adapter(provider: str = "stub") -> TTSAdapter:
    """Get a TTS adapter for the given provider name.

    Args:
        provider: Provider identifier (stub, elevenlabs, system).

    Returns:
        Configured TTSAdapter instance.
    """
    if provider == "elevenlabs":
        api_key = os.environ.get("ELEVENLABS_API_KEY", "")
        if api_key:
            return ElevenLabsTTSAdapter(api_key=api_key)
        return StubTTSAdapter()
    return StubTTSAdapter()
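Both adapter factories follow the same pattern: return the real provider only when its credential is configured, otherwise fall back to a stub so the pipeline never hard-fails on missing keys. A self-contained sketch of that pattern (the `StubTTS`/`ElevenLabsTTS`/`get_adapter` names here are illustrative stand-ins, not the commit's classes):

```python
import os


class StubTTS:
    """Fallback adapter that never requires credentials."""
    name = "stub"


class ElevenLabsTTS:
    """Stand-in for a real provider adapter that needs an API key."""
    name = "elevenlabs"

    def __init__(self, api_key: str) -> None:
        self.api_key = api_key


def get_adapter(provider: str = "stub"):
    """Return the provider's adapter when configured, else fall back to the stub."""
    if provider == "elevenlabs":
        key = os.environ.get("ELEVENLABS_API_KEY", "")
        if key:
            return ElevenLabsTTS(api_key=key)
    return StubTTS()
```

The design choice is graceful degradation: callers always get a working adapter, and enabling a provider is purely an environment change.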
@@ -167,6 +167,26 @@ def create_app(
    def metrics_endpoint() -> dict[str, Any]:
        return get_metrics().snapshot()

    # Health check endpoints (no auth required)
    _start_time = time.time()

    @app.get("/health", tags=["monitoring"])
    def health_check() -> dict[str, Any]:
        """Basic health check for load balancer probes."""
        return {"status": "healthy", "uptime_seconds": round(time.time() - _start_time, 1)}

    @app.get("/ready", tags=["monitoring"])
    def readiness_check() -> dict[str, Any]:
        """Readiness probe. Returns 503 if not initialized."""
        ready = getattr(app.state, "_dvadasa_ready", False)
        if not ready:
            from starlette.responses import JSONResponse

            return JSONResponse(  # type: ignore[return-value]
                content={"status": "not_ready"},
                status_code=503,
            )
        return {"status": "ready", "uptime_seconds": round(time.time() - _start_time, 1)}

    # Version info endpoint
    @app.get("/version", tags=["meta"])
    def version_info() -> dict[str, Any]:
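The readiness probe above gates on an app-state flag: 503 until initialization completes, then 200 with uptime. A framework-free sketch of that contract (the `readiness` function name is illustrative):

```python
def readiness(ready: bool, uptime: float) -> tuple[int, dict]:
    """Map the initialization flag to the probe's (status_code, body) pair."""
    if not ready:
        # Orchestrators keep traffic away until this flips to 200.
        return 503, {"status": "not_ready"}
    return 200, {"status": "ready", "uptime_seconds": round(uptime, 1)}
```

Separating `/health` (process is alive) from `/ready` (process can serve) is what lets a load balancer keep a slow-starting instance out of rotation without restarting it.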
61
fusionagi/api/cache.py
Normal file
@@ -0,0 +1,61 @@
"""In-memory response cache with TTL for the FusionAGI API."""

import hashlib
import json
import time
from typing import Any


class ResponseCache:
    """LRU-like response cache with configurable TTL.

    For production, replace with Redis-backed cache.
    """

    def __init__(self, max_size: int = 1000, ttl_seconds: float = 300.0) -> None:
        self._cache: dict[str, tuple[float, Any]] = {}
        self._max_size = max_size
        self._ttl = ttl_seconds

    @staticmethod
    def _make_key(prompt: str, session_id: str, tenant_id: str = "default") -> str:
        """Generate a cache key from prompt + session context."""
        raw = json.dumps({"prompt": prompt, "session": session_id, "tenant": tenant_id}, sort_keys=True)
        return hashlib.sha256(raw.encode()).hexdigest()

    def get(self, prompt: str, session_id: str, tenant_id: str = "default") -> Any | None:
        """Get cached response if it exists and hasn't expired."""
        key = self._make_key(prompt, session_id, tenant_id)
        entry = self._cache.get(key)
        if entry is None:
            return None
        ts, value = entry
        if time.time() - ts > self._ttl:
            del self._cache[key]
            return None
        return value

    def set(self, prompt: str, session_id: str, value: Any, tenant_id: str = "default") -> None:
        """Cache a response."""
        if len(self._cache) >= self._max_size:
            oldest_key = min(self._cache, key=lambda k: self._cache[k][0])
            del self._cache[oldest_key]
        key = self._make_key(prompt, session_id, tenant_id)
        self._cache[key] = (time.time(), value)

    def invalidate(self, prompt: str, session_id: str, tenant_id: str = "default") -> bool:
        """Remove a specific cache entry."""
        key = self._make_key(prompt, session_id, tenant_id)
        return self._cache.pop(key, None) is not None

    def clear(self) -> int:
        """Clear all cache entries. Returns count of cleared entries."""
        count = len(self._cache)
        self._cache.clear()
        return count

    def stats(self) -> dict[str, int]:
        """Return cache statistics."""
        now = time.time()
        active = sum(1 for ts, _ in self._cache.values() if now - ts <= self._ttl)
        return {"total": len(self._cache), "active": active, "max_size": self._max_size}
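The cache key derivation above is deterministic: serializing the context with `sort_keys=True` before hashing guarantees the same prompt/session/tenant triple always maps to the same SHA-256 digest, and tenant scoping falls out for free. A standalone sketch (`make_key` is an illustrative name for the same construction):

```python
import hashlib
import json


def make_key(prompt: str, session: str, tenant: str = "default") -> str:
    """Stable cache key: sorted-key JSON of the context, then SHA-256 hex digest."""
    raw = json.dumps({"prompt": prompt, "session": session, "tenant": tenant}, sort_keys=True)
    return hashlib.sha256(raw.encode()).hexdigest()
```

Hashing the serialized context keeps keys fixed-length regardless of prompt size, which matters when keys end up in a shared store such as Redis.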
97
fusionagi/api/pool.py
Normal file
@@ -0,0 +1,97 @@
"""Connection pool for backend services."""

import asyncio
from typing import Any, Protocol


class ConnectionProtocol(Protocol):
    """Protocol for poolable connections."""

    async def connect(self) -> None: ...
    async def close(self) -> None: ...
    def is_alive(self) -> bool: ...


class ConnectionPool:
    """Async connection pool with health checks and automatic recycling.

    Generic pool for database connections, HTTP clients, or any poolable resource.
    """

    def __init__(
        self,
        factory: Any,
        min_size: int = 2,
        max_size: int = 10,
        max_idle_seconds: float = 300.0,
    ) -> None:
        self._factory = factory
        self._min_size = min_size
        self._max_size = max_size
        self._max_idle = max_idle_seconds
        self._available: asyncio.Queue[Any] = asyncio.Queue(maxsize=max_size)
        self._in_use: int = 0
        self._total_created: int = 0
        self._initialized = False

    async def initialize(self) -> None:
        """Pre-populate pool with min_size connections."""
        if self._initialized:
            return
        for _ in range(self._min_size):
            conn = await self._create_connection()
            await self._available.put(conn)
        self._initialized = True

    async def _create_connection(self) -> Any:
        """Create a new connection via the factory."""
        conn = self._factory()
        if hasattr(conn, 'connect'):
            await conn.connect()
        self._total_created += 1
        return conn

    async def acquire(self) -> Any:
        """Acquire a connection from the pool."""
        if not self._initialized:
            await self.initialize()

        try:
            conn = self._available.get_nowait()
            if hasattr(conn, 'is_alive') and not conn.is_alive():
                conn = await self._create_connection()
        except asyncio.QueueEmpty:
            if self._in_use + self._available.qsize() < self._max_size:
                conn = await self._create_connection()
            else:
                conn = await self._available.get()

        self._in_use += 1
        return conn

    async def release(self, conn: Any) -> None:
        """Return a connection to the pool."""
        self._in_use -= 1
        try:
            self._available.put_nowait(conn)
        except asyncio.QueueFull:
            if hasattr(conn, 'close'):
                await conn.close()

    async def close_all(self) -> None:
        """Close all connections in the pool."""
        while not self._available.empty():
            conn = self._available.get_nowait()
            if hasattr(conn, 'close'):
                await conn.close()
        self._initialized = False
        self._in_use = 0

    def stats(self) -> dict[str, int]:
        """Return pool statistics."""
        return {
            "available": self._available.qsize(),
            "in_use": self._in_use,
            "total_created": self._total_created,
            "max_size": self._max_size,
        }
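The acquire path above is the core of the pool: try a non-blocking take from the queue, and only create a fresh connection when the queue is empty and the cap allows it; releasing puts the connection back for reuse. A minimal runnable sketch of that flow under the same assumptions (string "connections" stand in for real ones; `demo` is an illustrative name):

```python
import asyncio


async def demo() -> dict:
    """Bounded pool sketch: non-blocking acquire from a queue, create on empty."""
    pool: asyncio.Queue = asyncio.Queue(maxsize=2)
    created = 0

    async def acquire() -> str:
        nonlocal created
        try:
            return pool.get_nowait()       # reuse an idle connection if any
        except asyncio.QueueEmpty:
            created += 1                   # otherwise pay the creation cost
            return f"conn-{created}"

    c1 = await acquire()                   # creates conn-1
    c2 = await acquire()                   # creates conn-2
    pool.put_nowait(c1)                    # release c1 back to the pool
    c3 = await acquire()                   # reuses c1, no new connection
    assert c2 != c3
    return {"created": created, "reused": c3 == c1}


result = asyncio.run(demo())
```

Reuse is the whole point: only two connections are ever created even though three acquires happen.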
@@ -29,7 +29,17 @@ def _ensure_init():
@router.post("")
def create_session(user_id: str | None = None) -> dict[str, Any]:
    """Create a new session."""
    """Create a new FusionAGI session.

    Returns a session_id that can be used for subsequent prompts.
    Each session maintains its own conversation history and context.

    Args:
        user_id: Optional user identifier for tenant-scoped sessions.

    Returns:
        JSON with session_id and user_id.
    """
    _ensure_init()
    store = get_session_store()
    if not store:
@@ -41,7 +51,22 @@ def create_session(user_id: str | None = None) -> dict[str, Any]:

@router.post("/{session_id}/prompt")
def submit_prompt(session_id: str, body: dict[str, Any]) -> dict[str, Any]:
    """Submit a prompt and receive FinalResponse (sync)."""
    """Submit a prompt to the 12-headed Dvādaśa pipeline.

    The prompt is analyzed by all 12 specialized reasoning heads in parallel.
    Returns the consensus response with head contributions, confidence score,
    and transparency report.

    Supports commands: /head <name>, /show dissent, /sources, /explain.

    Args:
        session_id: Active session identifier.
        body: JSON body with 'prompt' field.

    Returns:
        FinalResponse with final_answer, head_contributions, confidence_score,
        and transparency_report.
    """
    _ensure_init()
    store = get_session_store()
    orch = get_orchestrator()
@@ -3,9 +3,10 @@
from __future__ import annotations

import os
import time
from typing import Any

from fastapi import APIRouter, Header
from fastapi import APIRouter, Header, HTTPException

from fusionagi._logger import logger

@@ -13,6 +14,17 @@ router = APIRouter()

DEFAULT_TENANT = os.environ.get("FUSIONAGI_DEFAULT_TENANT", "default")

# In-memory tenant registry; for production, back with Postgres
_tenant_store: dict[str, dict[str, Any]] = {
    DEFAULT_TENANT: {
        "id": DEFAULT_TENANT,
        "name": "Default Tenant",
        "status": "active",
        "created_at": time.time(),
        "config": {},
    }
}


def resolve_tenant(x_tenant_id: str | None = Header(default=None)) -> str:
    """Resolve tenant from X-Tenant-ID header or default."""
@@ -21,32 +33,121 @@ def resolve_tenant(x_tenant_id: str | None = Header(default=None)) -> str:

@router.get("/tenants/current")
def get_current_tenant(x_tenant_id: str | None = Header(default=None)) -> dict[str, Any]:
    """Return the resolved tenant context."""
    """Return the resolved tenant context.

    The tenant is determined from the X-Tenant-ID header.
    Falls back to the default tenant if no header is provided.
    """
    tid = resolve_tenant(x_tenant_id)
    return {
        "tenant_id": tid,
        "is_default": tid == DEFAULT_TENANT,
        "isolation_mode": "logical",
        "exists": tid in _tenant_store,
    }


@router.get("/tenants")
def list_tenants() -> dict[str, Any]:
    """List known tenants (placeholder — in production, query tenant registry)."""
    return {
        "tenants": [
            {"id": DEFAULT_TENANT, "name": "Default Tenant", "status": "active"},
        ],
        "total": 1,
    }
    """List all registered tenants.

    Returns:
        JSON with tenants array and total count.
    """
    tenants = list(_tenant_store.values())
    return {"tenants": tenants, "total": len(tenants)}


@router.get("/tenants/{tenant_id}")
def get_tenant(tenant_id: str) -> dict[str, Any]:
    """Get a specific tenant by ID.

    Args:
        tenant_id: Tenant identifier.

    Returns:
        Tenant record.

    Raises:
        404 if tenant not found.
    """
    tenant = _tenant_store.get(tenant_id)
    if not tenant:
        raise HTTPException(status_code=404, detail=f"Tenant {tenant_id} not found")
    return tenant


@router.post("/tenants")
def create_tenant(body: dict[str, Any]) -> dict[str, Any]:
    """Register a new tenant."""
    """Register a new tenant.

    Args:
        body: JSON with 'id' and optional 'name', 'config' fields.

    Returns:
        Created tenant record.
    """
    tenant_id = body.get("id", "")
    name = body.get("name", tenant_id)
    if not tenant_id:
        return {"error": "Tenant ID required"}
        raise HTTPException(status_code=400, detail="Tenant ID required")
    if tenant_id in _tenant_store:
        raise HTTPException(status_code=409, detail=f"Tenant {tenant_id} already exists")

    name = body.get("name", tenant_id)
    config = body.get("config", {})
    tenant = {
        "id": tenant_id,
        "name": name,
        "status": "active",
        "created_at": time.time(),
        "config": config,
    }
    _tenant_store[tenant_id] = tenant
    logger.info("Tenant created", extra={"tenant_id": tenant_id, "name": name})
    return {"id": tenant_id, "name": name, "status": "active"}
    return tenant


@router.put("/tenants/{tenant_id}")
def update_tenant(tenant_id: str, body: dict[str, Any]) -> dict[str, Any]:
    """Update tenant configuration.

    Args:
        tenant_id: Tenant identifier.
        body: JSON with fields to update (name, config, status).

    Returns:
        Updated tenant record.
    """
    tenant = _tenant_store.get(tenant_id)
    if not tenant:
        raise HTTPException(status_code=404, detail=f"Tenant {tenant_id} not found")

    if "name" in body:
        tenant["name"] = body["name"]
    if "config" in body:
        tenant["config"] = body["config"]
    if "status" in body:
        tenant["status"] = body["status"]

    logger.info("Tenant updated", extra={"tenant_id": tenant_id})
    return tenant


@router.delete("/tenants/{tenant_id}")
def deactivate_tenant(tenant_id: str) -> dict[str, Any]:
    """Deactivate a tenant (soft delete).

    Args:
        tenant_id: Tenant identifier.

    Returns:
        Confirmation with tenant status.
    """
    if tenant_id == DEFAULT_TENANT:
        raise HTTPException(status_code=400, detail="Cannot deactivate default tenant")
    tenant = _tenant_store.get(tenant_id)
    if not tenant:
        raise HTTPException(status_code=404, detail=f"Tenant {tenant_id} not found")
    tenant["status"] = "inactive"
    logger.info("Tenant deactivated", extra={"tenant_id": tenant_id})
    return {"id": tenant_id, "status": "inactive"}
102
fusionagi/api/secret_rotation.py
Normal file
@@ -0,0 +1,102 @@
"""API key rotation mechanism for FusionAGI."""

from __future__ import annotations

import hashlib
import secrets
import time
from typing import Any

from pydantic import BaseModel, Field


class APIKeyRecord(BaseModel):
    """Record for a rotatable API key."""

    key_hash: str
    created_at: float = Field(default_factory=time.time)
    expires_at: float | None = None
    label: str = "default"
    active: bool = True


class SecretRotator:
    """Manages API key lifecycle: generation, rotation, and expiry.

    Keys are stored as SHA-256 hashes for security.
    Supports multiple active keys for zero-downtime rotation.
    """

    def __init__(self, max_active_keys: int = 3) -> None:
        self._keys: list[APIKeyRecord] = []
        self._max_active = max_active_keys

    @staticmethod
    def _hash_key(key: str) -> str:
        """Hash a key using SHA-256."""
        return hashlib.sha256(key.encode()).hexdigest()

    def generate_key(self, label: str = "default", ttl_seconds: float | None = None) -> str:
        """Generate a new API key and register it. Returns the plaintext key."""
        key = secrets.token_urlsafe(32)
        record = APIKeyRecord(
            key_hash=self._hash_key(key),
            label=label,
            expires_at=time.time() + ttl_seconds if ttl_seconds else None,
        )
        self._keys.append(record)
        self._enforce_max_active()
        return key

    def validate_key(self, key: str) -> bool:
        """Check if a key is valid (active and not expired)."""
        key_hash = self._hash_key(key)
        now = time.time()
        for record in self._keys:
            if record.key_hash == key_hash and record.active:
                if record.expires_at and now > record.expires_at:
                    record.active = False
                    return False
                return True
        return False

    def rotate(self, label: str = "default", ttl_seconds: float | None = None) -> str:
        """Rotate keys: generate new, keep previous active for overlap period."""
        return self.generate_key(label=label, ttl_seconds=ttl_seconds)

    def revoke(self, key: str) -> bool:
        """Revoke a specific key."""
        key_hash = self._hash_key(key)
        for record in self._keys:
            if record.key_hash == key_hash:
                record.active = False
                return True
        return False

    def revoke_expired(self) -> int:
        """Deactivate all expired keys."""
        now = time.time()
        count = 0
        for record in self._keys:
            if record.active and record.expires_at and now > record.expires_at:
                record.active = False
                count += 1
        return count

    def _enforce_max_active(self) -> None:
        """Ensure we don't exceed max active keys."""
        active = [k for k in self._keys if k.active]
        while len(active) > self._max_active:
            active[0].active = False
            active = active[1:]

    def list_keys(self) -> list[dict[str, Any]]:
        """List all keys (without hashes)."""
        return [
            {
                "label": k.label,
                "active": k.active,
                "created_at": k.created_at,
                "expires_at": k.expires_at,
            }
            for k in self._keys
        ]
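The rotator never stores plaintext: the key is shown once at generation, and only its SHA-256 digest is kept, so validation is a hash-and-compare. A minimal sketch of that store-hash/validate-hash loop (the `hash_key`/`validate` names and the `stored_hashes` set are illustrative):

```python
import hashlib
import secrets


def hash_key(key: str) -> str:
    """SHA-256 hex digest of a key; this is all that is persisted."""
    return hashlib.sha256(key.encode()).hexdigest()


stored_hashes: set[str] = set()

plaintext = secrets.token_urlsafe(32)   # returned to the caller once, never stored
stored_hashes.add(hash_key(plaintext))


def validate(candidate: str) -> bool:
    """A candidate key is valid iff its digest matches a stored digest."""
    return hash_key(candidate) in stored_hashes
```

Because only digests are at rest, a leaked key store does not leak usable credentials; rotation then just means adding a new digest while the old one stays valid for the overlap window.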
106
fusionagi/api/task_queue.py
Normal file
@@ -0,0 +1,106 @@
"""Async background task queue for long-running operations."""

import asyncio
import time
import uuid
from enum import Enum
from typing import Any, Callable, Coroutine

from pydantic import BaseModel, Field


class TaskStatus(str, Enum):
    """Background task status."""

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"


class TaskResult(BaseModel):
    """Result of a background task."""

    task_id: str
    status: TaskStatus
    result: Any = None
    error: str | None = None
    created_at: float = Field(default_factory=time.time)
    completed_at: float | None = None
    duration_ms: float | None = None


class BackgroundTaskQueue:
    """Async task queue for offloading long-running work.

    Tasks are submitted and run concurrently via asyncio. Results are
    stored in-memory and queryable by task_id.
    """

    def __init__(self, max_concurrent: int = 5, result_ttl: float = 3600.0) -> None:
        self._semaphore = asyncio.Semaphore(max_concurrent)
        self._results: dict[str, TaskResult] = {}
        self._tasks: dict[str, asyncio.Task[None]] = {}
        self._result_ttl = result_ttl

    def submit(
        self,
        fn: Callable[..., Coroutine[Any, Any, Any]],
        *args: Any,
        task_id: str | None = None,
        **kwargs: Any,
    ) -> str:
        """Submit a coroutine to run in the background. Returns task_id."""
        tid = task_id or str(uuid.uuid4())
        self._results[tid] = TaskResult(task_id=tid, status=TaskStatus.PENDING)

        async def _runner() -> None:
            async with self._semaphore:
                self._results[tid].status = TaskStatus.RUNNING
                start = time.time()
                try:
                    result = await fn(*args, **kwargs)
                    self._results[tid].result = result
                    self._results[tid].status = TaskStatus.COMPLETED
                except Exception as e:
                    self._results[tid].error = str(e)
                    self._results[tid].status = TaskStatus.FAILED
                finally:
                    self._results[tid].completed_at = time.time()
                    self._results[tid].duration_ms = (time.time() - start) * 1000

        loop = asyncio.get_event_loop()
        task = loop.create_task(_runner())
        self._tasks[tid] = task
        return tid

    def get_status(self, task_id: str) -> TaskResult | None:
        """Get the status and result of a task."""
        return self._results.get(task_id)

    def cancel(self, task_id: str) -> bool:
        """Cancel a pending or running task."""
        task = self._tasks.get(task_id)
        if task and not task.done():
            task.cancel()
            self._results[task_id].status = TaskStatus.CANCELLED
            return True
        return False

    def list_tasks(self, status: TaskStatus | None = None) -> list[TaskResult]:
        """List all tasks, optionally filtered by status."""
        results = list(self._results.values())
        if status:
            results = [r for r in results if r.status == status]
        return results

    def cleanup_expired(self) -> int:
        """Remove completed tasks older than result_ttl."""
        now = time.time()
        expired = [
            tid for tid, r in self._results.items()
            if r.completed_at and (now - r.completed_at) > self._result_ttl
        ]
        for tid in expired:
            del self._results[tid]
            self._tasks.pop(tid, None)
        return len(expired)
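The queue's `_runner` wraps every submitted coroutine in `async with self._semaphore`, so at most `max_concurrent` tasks run at once while the rest wait. That bounded-concurrency core can be shown in isolation (the `run_tasks`/`job` names are illustrative):

```python
import asyncio


async def run_tasks() -> list[str]:
    """Run four jobs concurrently, but allow at most two inside the guarded section."""
    sem = asyncio.Semaphore(2)
    done: list[str] = []

    async def job(name: str) -> None:
        async with sem:             # blocks here once two jobs hold the semaphore
            await asyncio.sleep(0)  # stand-in for real long-running work
            done.append(name)

    await asyncio.gather(*(job(f"t{i}") for i in range(4)))
    return done


results = asyncio.run(run_tasks())
```

All jobs still complete; the semaphore only shapes how many are in flight simultaneously, which is what keeps a burst of submissions from starving the event loop.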
64
fusionagi/api/tracing.py
Normal file
@@ -0,0 +1,64 @@
"""Request tracing middleware for structured logging with correlation IDs."""

from __future__ import annotations

import contextvars
import uuid
from typing import Any

trace_id_var: contextvars.ContextVar[str] = contextvars.ContextVar("trace_id", default="")


def get_trace_id() -> str:
    """Get current trace ID from context."""
    return trace_id_var.get() or ""


def set_trace_id(trace_id: str) -> None:
    """Set trace ID in current context."""
    trace_id_var.set(trace_id)


def generate_trace_id() -> str:
    """Generate a new trace ID."""
    return str(uuid.uuid4())[:8]


class TracingMiddleware:
    """ASGI middleware that sets/propagates request trace IDs.

    Extracts trace ID from X-Request-ID header or generates a new one.
    Injects trace ID into response headers and logging context.
    """

    def __init__(self, app: Any, header_name: str = "X-Request-ID") -> None:
        self.app = app
        self.header_name = header_name.lower()

    async def __call__(self, scope: dict[str, Any], receive: Any, send: Any) -> None:
        """ASGI entrypoint."""
        if scope["type"] not in ("http", "websocket"):
            await self.app(scope, receive, send)
            return

        headers = dict(scope.get("headers", []))
        trace_id = ""
        for k, v in headers.items():
            if isinstance(k, bytes) and k.decode("latin-1").lower() == self.header_name:
                trace_id = v.decode("latin-1") if isinstance(v, bytes) else str(v)
                break

        if not trace_id:
            trace_id = generate_trace_id()

        set_trace_id(trace_id)

        async def send_with_trace(message: dict[str, Any]) -> None:
            if message["type"] == "http.response.start":
                headers_list = list(message.get("headers", []))
                headers_list.append((b"x-request-id", trace_id.encode()))
                headers_list.append((b"x-trace-id", trace_id.encode()))
                message["headers"] = headers_list
            await send(message)

        await self.app(scope, receive, send_with_trace)
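The middleware's first job is to pull the caller's correlation ID out of the raw ASGI header list (pairs of bytes, matched case-insensitively) and mint a short one when absent. That lookup in isolation (the `extract_trace_id` name is illustrative):

```python
import uuid


def extract_trace_id(headers: list[tuple[bytes, bytes]], header_name: str = "x-request-id") -> str:
    """Return the caller's X-Request-ID from raw ASGI headers, or mint a short ID."""
    for key, value in headers:
        if key.decode("latin-1").lower() == header_name:
            return value.decode("latin-1")
    return str(uuid.uuid4())[:8]
```

Echoing the same ID back in the response headers and the logging context is what lets one grep correlate a client request with every log line it produced.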
@@ -318,12 +318,11 @@ class VoiceInterface(InterfaceAdapter):
        Returns:
            Audio data as bytes.
        """
        # Integrate with TTS provider based on self.tts_provider
        # - system: Use OS TTS (pyttsx3, etc.)
        # - elevenlabs: Use ElevenLabs API
        # - azure: Use Azure Cognitive Services
        # - google: Use Google Cloud TTS
        raise NotImplementedError("TTS provider integration required")
        from fusionagi.adapters.tts import get_tts_adapter

        adapter = get_tts_adapter(self.tts_provider)
        voice_id = voice.voice_id if voice else None
        return await adapter.synthesize(text, voice_id=voice_id)

    async def _transcribe_speech(self, audio_data: bytes) -> str:
        """
@@ -335,9 +334,7 @@ class VoiceInterface(InterfaceAdapter):
        Returns:
            Transcribed text.
        """
        # Integrate with STT provider based on self.stt_provider
        # - whisper: Use OpenAI Whisper (local or API)
        # - azure: Use Azure Cognitive Services
        # - google: Use Google Cloud Speech-to-Text
        # - deepgram: Use Deepgram API
        raise NotImplementedError("STT provider integration required")
        from fusionagi.adapters.stt import get_stt_adapter

        adapter = get_stt_adapter(self.stt_provider)
        return await adapter.transcribe(audio_data)
@@ -46,15 +46,20 @@ class GeometryAuthorityInterface(ABC):
class InMemoryGeometryKernel(GeometryAuthorityInterface):
    """
    In-memory lineage model; no concrete CAD kernel.
    Only tracks features registered via add_feature; validate_no_orphans returns []
    since every stored feature has lineage. For a kernel that tracks all feature ids
    separately, override validate_no_orphans to return ids not in lineage.
    """In-memory geometry lineage model with orphan detection.

    Tracks both registered features (with lineage) and all known feature IDs.
    Features added via ``register_feature_id`` without a corresponding
    ``add_feature`` call are considered orphans.
    """

    def __init__(self) -> None:
        self._lineage: dict[str, FeatureLineageEntry] = {}
        self._all_feature_ids: set[str] = set()

    def register_feature_id(self, feature_id: str) -> None:
        """Register a feature ID from the geometry model (may not have lineage yet)."""
        self._all_feature_ids.add(feature_id)

    def add_feature(
        self,
@@ -71,11 +76,27 @@ class InMemoryGeometryKernel(GeometryAuthorityInterface):
            process_eligible=process_eligible,
        )
        self._lineage[feature_id] = entry
        self._all_feature_ids.add(feature_id)
        return entry

    def get_lineage(self, feature_id: str) -> FeatureLineageEntry | None:
        return self._lineage.get(feature_id)

    def remove_feature(self, feature_id: str) -> bool:
        """Remove a feature and its lineage."""
        removed = feature_id in self._lineage
        self._lineage.pop(feature_id, None)
        self._all_feature_ids.discard(feature_id)
        return removed

    def validate_no_orphans(self) -> list[str]:
        """Return []; this stub only tracks registered features, so none are orphans."""
        return []
        """Return feature IDs that exist but have no valid lineage."""
        return [fid for fid in self._all_feature_ids if fid not in self._lineage]

    def list_features(self) -> list[str]:
        """Return all known feature IDs."""
        return sorted(self._all_feature_ids)

    def count(self) -> int:
        """Return total feature count."""
        return len(self._all_feature_ids)
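Orphan detection in the kernel reduces to a set difference: every ID the model knows about, minus every ID that has a lineage entry. A standalone sketch with hypothetical sample data (`orphans`, `all_ids`, and `lineage` are illustrative names; real lineage values would be `FeatureLineageEntry` records, not strings):

```python
all_ids = {"f1", "f2", "f3"}                      # every feature the model knows
lineage = {"f1": "sketch-1", "f3": "extrude-2"}   # features with recorded provenance


def orphans(all_feature_ids: set[str], lineage: dict[str, str]) -> list[str]:
    """Features known to the model but lacking a lineage entry."""
    return sorted(fid for fid in all_feature_ids if fid not in lineage)
```

Keeping the two collections separate is what makes the check meaningful: a feature registered from the geometry model but never given lineage shows up immediately.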
@@ -16,22 +16,49 @@ def _scoped_key(tenant_id: str, user_id: str, base: str) -> str:

class VectorMemory:
    """
    Vector memory for embeddings retrieval.

    Uses in-memory cosine similarity search. For production, swap with a
    pgvector, Pinecone, or Qdrant adapter behind the same interface.
    """

    def __init__(self, max_entries: int = 10000) -> None:
        self._store: list[dict[str, Any]] = []
        self._max_entries = max_entries

    @staticmethod
    def _cosine_similarity(a: list[float], b: list[float]) -> float:
        """Compute cosine similarity between two vectors."""
        dot = sum(x * y for x, y in zip(a, b))
        norm_a = sum(x * x for x in a) ** 0.5
        norm_b = sum(x * x for x in b) ** 0.5
        if norm_a == 0 or norm_b == 0:
            return 0.0
        return dot / (norm_a * norm_b)

    def add(self, id: str, embedding: list[float], metadata: dict[str, Any] | None = None) -> None:
        """Add embedding to the vector store."""
        if len(self._store) >= self._max_entries:
            self._store.pop(0)
        self._store.append({"id": id, "embedding": embedding, "metadata": metadata or {}})

    def search(self, query_embedding: list[float], top_k: int = 10) -> list[dict[str, Any]]:
        """Search by cosine similarity, returning top-k results."""
        scored = []
        for entry in self._store:
            sim = self._cosine_similarity(query_embedding, entry["embedding"])
            scored.append({"id": entry["id"], "metadata": entry["metadata"], "score": sim})
        scored.sort(key=lambda x: x["score"], reverse=True)
        return scored[:top_k]

    def delete(self, id: str) -> bool:
        """Remove an entry by ID."""
        before = len(self._store)
        self._store = [e for e in self._store if e["id"] != id]
        return len(self._store) < before

    def count(self) -> int:
        """Return entry count."""
        return len(self._store)


class MemoryService:
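The scoring formula in `_cosine_similarity` can be checked in isolation. This is a standalone copy of the same arithmetic (not an import of the class above), useful as a quick sanity sketch:

```python
# Standalone copy of the cosine-similarity formula used by VectorMemory above.
def cosine(a: list[float], b: list[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = sum(x * x for x in a) ** 0.5
    norm_b = sum(x * x for x in b) ** 0.5
    if norm_a == 0 or norm_b == 0:
        return 0.0  # zero-vector guard, as in the class
    return dot / (norm_a * norm_b)

print(cosine([1.0, 0.0], [1.0, 0.0]))  # parallel vectors -> 1.0
print(cosine([1.0, 0.0], [0.0, 1.0]))  # orthogonal vectors -> 0.0
```

Parallel vectors score 1.0 and orthogonal vectors 0.0, which is what `search` relies on when ranking entries.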
fusionagi/settings.py (new file, 106 lines)
@@ -0,0 +1,106 @@
"""Environment-based configuration using Pydantic models.

All settings are configurable via environment variables.
"""

from __future__ import annotations

from pydantic import BaseModel, Field


class APIConfig(BaseModel):
    """API server configuration."""

    host: str = Field(default="0.0.0.0", description="Server bind host")
    port: int = Field(default=8000, description="Server bind port")
    workers: int = Field(default=1, description="Number of worker processes")
    cors_origins: list[str] = Field(default=["*"], description="CORS allowed origins")
    api_key: str | None = Field(default=None, description="API key for authentication")
    rate_limit: int = Field(default=120, description="Rate limit (requests per window)")
    rate_window: float = Field(default=60.0, description="Rate limit window in seconds")


class DatabaseConfig(BaseModel):
    """Database configuration."""

    url: str = Field(default="sqlite:///fusionagi.db", description="Database URL")
    pool_size: int = Field(default=5, description="Connection pool size")
    max_overflow: int = Field(default=10, description="Max overflow connections")
    echo: bool = Field(default=False, description="Echo SQL statements")


class CacheConfig(BaseModel):
    """Cache configuration."""

    enabled: bool = Field(default=True, description="Enable response caching")
    max_size: int = Field(default=1000, description="Max cached entries")
    ttl_seconds: float = Field(default=300.0, description="Cache TTL in seconds")
    backend: str = Field(default="memory", description="Cache backend (memory or redis)")
    redis_url: str | None = Field(default=None, description="Redis URL if backend is redis")


class LoggingConfig(BaseModel):
    """Logging configuration."""

    level: str = Field(default="INFO", description="Log level")
    format: str = Field(default="json", description="Log format (json or text)")
    correlation_id_header: str = Field(default="X-Request-ID", description="Request ID header")


class GovernanceConfig(BaseModel):
    """Governance configuration."""

    mode: str = Field(default="advisory", description="Governance mode (advisory or enforcing)")
    max_file_size: int | None = Field(default=None, description="Max file size in bytes (None=unlimited)")
    allow_private_urls: bool = Field(default=True, description="Allow private/internal URLs")


class FusionAGIConfig(BaseModel):
    """Root configuration for FusionAGI."""

    api: APIConfig = Field(default_factory=APIConfig)
    database: DatabaseConfig = Field(default_factory=DatabaseConfig)
    cache: CacheConfig = Field(default_factory=CacheConfig)
    logging: LoggingConfig = Field(default_factory=LoggingConfig)
    governance: GovernanceConfig = Field(default_factory=GovernanceConfig)
    tenant_isolation: bool = Field(default=True, description="Enable tenant isolation")
    max_concurrent_tasks: int = Field(default=5, description="Max background tasks")


def load_config() -> FusionAGIConfig:
    """Load configuration from environment variables.

    Environment variables are mapped using the pattern:
        FUSIONAGI_<SECTION>_<KEY> (e.g., FUSIONAGI_API_PORT=9000)
    """
    import os

    config = FusionAGIConfig()

    env_map = {
        "FUSIONAGI_API_HOST": ("api", "host"),
        "FUSIONAGI_API_PORT": ("api", "port"),
        "FUSIONAGI_API_WORKERS": ("api", "workers"),
        "FUSIONAGI_API_KEY": ("api", "api_key"),
        "FUSIONAGI_RATE_LIMIT": ("api", "rate_limit"),
        "FUSIONAGI_RATE_WINDOW": ("api", "rate_window"),
        "FUSIONAGI_DB_URL": ("database", "url"),
        "FUSIONAGI_DB_POOL_SIZE": ("database", "pool_size"),
        "FUSIONAGI_CACHE_ENABLED": ("cache", "enabled"),
        "FUSIONAGI_CACHE_TTL": ("cache", "ttl_seconds"),
        "FUSIONAGI_CACHE_BACKEND": ("cache", "backend"),
        "FUSIONAGI_REDIS_URL": ("cache", "redis_url"),
        "FUSIONAGI_LOG_LEVEL": ("logging", "level"),
        "FUSIONAGI_LOG_FORMAT": ("logging", "format"),
        "FUSIONAGI_GOVERNANCE_MODE": ("governance", "mode"),
    }

    for env_var, (section, key) in env_map.items():
        value = os.environ.get(env_var)
        if value is not None:
            section_obj = getattr(config, section)
            field_info = type(section_obj).model_fields.get(key)
            if field_info and field_info.annotation:
                annotation = field_info.annotation
                if annotation is int:
                    value = int(value)  # type: ignore[assignment]
                elif annotation is float:
                    value = float(value)  # type: ignore[assignment]
                elif annotation is bool:
                    value = value.lower() in ("true", "1", "yes")  # type: ignore[assignment]
            setattr(section_obj, key, value)

    return config
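The coercion step in `load_config()` can be sketched on its own. A self-contained copy of the rule (same truthy strings as in the diff, with a hypothetical `coerce` helper name):

```python
# Minimal sketch of the type coercion load_config() applies to env values,
# dispatching on the field annotation (int, float, or bool).
def coerce(raw: str, annotation: type):
    if annotation is int:
        return int(raw)
    if annotation is float:
        return float(raw)
    if annotation is bool:
        # Same truthy set as in load_config()
        return raw.lower() in ("true", "1", "yes")
    return raw  # strings pass through unchanged

print(coerce("9000", int))   # -> 9000
print(coerce("yes", bool))   # -> True
print(coerce("off", bool))   # -> False
```

Anything outside the truthy set ("true", "1", "yes") coerces to `False`, so e.g. `FUSIONAGI_CACHE_ENABLED=off` disables caching.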
migrations/README.md (new file, 48 lines)
@@ -0,0 +1,48 @@
# Database Migrations

FusionAGI uses a lightweight migration system for schema changes.

## Structure

```
migrations/
├── README.md
├── versions/
│   └── 001_initial_schema.sql
└── migrate.py
```

## Usage

```bash
# Run all pending migrations
python -m migrations.migrate up

# Roll back the last migration
python -m migrations.migrate down

# Show migration status
python -m migrations.migrate status
```

## Creating a Migration

1. Create a new SQL file in `migrations/versions/`:
   ```
   NNN_description.sql
   ```

2. Include both `-- UP` and `-- DOWN` sections:
   ```sql
   -- UP
   CREATE TABLE example (...);

   -- DOWN
   DROP TABLE IF EXISTS example;
   ```

## Notes

- Migrations run in numeric order (001, 002, etc.)
- Each applied migration is tracked in a `_migrations` table
- For production, consider using Alembic with SQLAlchemy
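The `-- UP` / `-- DOWN` convention above is parsed by a plain string split; this is the same split logic `parse_migration` in `migrations/migrate.py` uses, shown standalone:

```python
# How the runner splits a migration file into UP and DOWN sections
# (mirrors parse_migration in migrations/migrate.py).
text = """-- UP
CREATE TABLE example (id INTEGER PRIMARY KEY);

-- DOWN
DROP TABLE IF EXISTS example;"""

parts = text.split("-- DOWN")
up_sql = parts[0].replace("-- UP", "").strip()
down_sql = parts[1].strip() if len(parts) > 1 else ""
print(up_sql)    # CREATE TABLE example (id INTEGER PRIMARY KEY);
print(down_sql)  # DROP TABLE IF EXISTS example;
```

A migration without a `-- DOWN` section yields an empty `down_sql`, which is why the runner refuses to roll it back.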
migrations/migrate.py (new file, 120 lines)
@@ -0,0 +1,120 @@
"""Lightweight database migration runner for FusionAGI.

Usage:
    python -m migrations.migrate up      # Apply all pending migrations
    python -m migrations.migrate down    # Roll back the last migration
    python -m migrations.migrate status  # Show migration status
"""

from __future__ import annotations

import os
import sqlite3
import sys
from pathlib import Path

VERSIONS_DIR = Path(__file__).parent / "versions"
DEFAULT_DB = os.environ.get("FUSIONAGI_DB_PATH", "fusionagi.db")


def get_connection(db_path: str = DEFAULT_DB) -> sqlite3.Connection:
    """Open the database and ensure the migration tracking table exists."""
    conn = sqlite3.connect(db_path)
    conn.execute(
        "CREATE TABLE IF NOT EXISTS _migrations "
        "(id INTEGER PRIMARY KEY AUTOINCREMENT, version TEXT NOT NULL UNIQUE, "
        "applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)"
    )
    conn.commit()
    return conn


def get_applied(conn: sqlite3.Connection) -> set[str]:
    """Return the set of applied migration versions."""
    rows = conn.execute("SELECT version FROM _migrations").fetchall()
    return {r[0] for r in rows}


def get_migration_files() -> list[tuple[str, Path]]:
    """Return a sorted list of (version, path) tuples."""
    files = sorted(VERSIONS_DIR.glob("*.sql"))
    return [(f.stem, f) for f in files]


def parse_migration(path: Path) -> tuple[str, str]:
    """Parse a migration file into (up_sql, down_sql)."""
    text = path.read_text()
    parts = text.split("-- DOWN")
    up_sql = parts[0].replace("-- UP", "").strip()
    down_sql = parts[1].strip() if len(parts) > 1 else ""
    return up_sql, down_sql


def migrate_up(db_path: str = DEFAULT_DB) -> int:
    """Apply all pending migrations. Returns the count applied."""
    conn = get_connection(db_path)
    applied = get_applied(conn)
    count = 0
    for version, path in get_migration_files():
        if version not in applied:
            up_sql, _ = parse_migration(path)
            conn.executescript(up_sql)
            conn.execute("INSERT INTO _migrations (version) VALUES (?)", (version,))
            conn.commit()
            print(f"Applied: {version}")
            count += 1
    if count == 0:
        print("No pending migrations.")
    return count


def migrate_down(db_path: str = DEFAULT_DB) -> bool:
    """Roll back the most recently applied migration."""
    conn = get_connection(db_path)
    applied = get_applied(conn)
    if not applied:
        print("No migrations to roll back.")
        return False

    migrations = get_migration_files()
    applied_migrations = [(v, p) for v, p in migrations if v in applied]
    if not applied_migrations:
        print("No migrations to roll back.")
        return False

    version, path = applied_migrations[-1]
    _, down_sql = parse_migration(path)
    if not down_sql:
        print(f"No DOWN section in {version}. Cannot roll back.")
        return False

    conn.executescript(down_sql)
    try:
        # The DOWN section may drop _migrations itself (001 does), so this can fail.
        conn.execute("DELETE FROM _migrations WHERE version = ?", (version,))
    except Exception:
        pass
    conn.commit()
    print(f"Rolled back: {version}")
    return True


def show_status(db_path: str = DEFAULT_DB) -> None:
    """Print applied/pending status for each migration."""
    conn = get_connection(db_path)
    applied = get_applied(conn)
    for version, _ in get_migration_files():
        status = "applied" if version in applied else "pending"
        print(f"  {version}: {status}")


if __name__ == "__main__":
    cmd = sys.argv[1] if len(sys.argv) > 1 else "status"
    db = sys.argv[2] if len(sys.argv) > 2 else DEFAULT_DB
    if cmd == "up":
        migrate_up(db)
    elif cmd == "down":
        migrate_down(db)
    elif cmd == "status":
        show_status(db)
    else:
        print(f"Unknown command: {cmd}. Use: up, down, status")
migrations/versions/001_initial_schema.sql (new file, 55 lines)
@@ -0,0 +1,55 @@
-- UP
CREATE TABLE IF NOT EXISTS _migrations (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    version TEXT NOT NULL UNIQUE,
    applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);

CREATE TABLE IF NOT EXISTS sessions (
    id TEXT PRIMARY KEY,
    tenant_id TEXT NOT NULL DEFAULT 'default',
    user_id TEXT,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);

CREATE TABLE IF NOT EXISTS ethical_lessons (
    id TEXT PRIMARY KEY,
    principle TEXT NOT NULL,
    description TEXT,
    weight REAL DEFAULT 1.0,
    source_task TEXT,
    outcome TEXT,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);

CREATE TABLE IF NOT EXISTS consequences (
    id TEXT PRIMARY KEY,
    action_id TEXT NOT NULL,
    choice_made TEXT NOT NULL,
    alternatives TEXT,
    expected_risk REAL DEFAULT 0.0,
    expected_reward REAL DEFAULT 0.0,
    actual_outcome TEXT,
    surprise_factor REAL DEFAULT 0.0,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);

CREATE TABLE IF NOT EXISTS tenants (
    id TEXT PRIMARY KEY,
    name TEXT NOT NULL,
    config TEXT,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    active INTEGER DEFAULT 1
);

CREATE INDEX IF NOT EXISTS idx_sessions_tenant ON sessions(tenant_id);
CREATE INDEX IF NOT EXISTS idx_consequences_action ON consequences(action_id);
CREATE INDEX IF NOT EXISTS idx_ethical_lessons_source ON ethical_lessons(source_task);

-- DOWN
DROP TABLE IF EXISTS tenants;
DROP TABLE IF EXISTS consequences;
DROP TABLE IF EXISTS ethical_lessons;
DROP TABLE IF EXISTS sessions;
DROP TABLE IF EXISTS _migrations;
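Because the UP section is plain SQL, it can be applied directly with `sqlite3.Connection.executescript`, which is exactly what the runner does. A sanity sketch against an in-memory database, using an abridged copy of the `sessions` table (not the full schema above):

```python
# Sanity sketch: apply an abridged slice of the UP section with executescript.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE IF NOT EXISTS sessions (
    id TEXT PRIMARY KEY,
    tenant_id TEXT NOT NULL DEFAULT 'default'
);
CREATE INDEX IF NOT EXISTS idx_sessions_tenant ON sessions(tenant_id);
""")
tables = [r[0] for r in conn.execute(
    "SELECT name FROM sqlite_master WHERE type='table'")]
print(tables)  # -> ['sessions']
```

Note that `executescript` issues an implicit COMMIT before running, so multi-statement migrations are not atomic under sqlite3's default behavior.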
monitoring/grafana-dashboard.json (new file, 74 lines)
@@ -0,0 +1,74 @@
{
  "dashboard": {
    "title": "FusionAGI Dvādaśa",
    "description": "Performance monitoring for the 12-headed AGI orchestrator",
    "tags": ["fusionagi", "ai", "orchestration"],
    "timezone": "browser",
    "panels": [
      {
        "title": "HTTP Request Rate",
        "type": "timeseries",
        "gridPos": {"h": 8, "w": 12, "x": 0, "y": 0},
        "targets": [{"expr": "rate(http_requests_total[5m])", "legendFormat": "{{method}} {{path}}"}]
      },
      {
        "title": "Response Latency (p50/p95/p99)",
        "type": "timeseries",
        "gridPos": {"h": 8, "w": 12, "x": 12, "y": 0},
        "targets": [
          {"expr": "histogram_quantile(0.50, rate(http_request_duration_seconds_bucket[5m]))", "legendFormat": "p50"},
          {"expr": "histogram_quantile(0.95, rate(http_request_duration_seconds_bucket[5m]))", "legendFormat": "p95"},
          {"expr": "histogram_quantile(0.99, rate(http_request_duration_seconds_bucket[5m]))", "legendFormat": "p99"}
        ]
      },
      {
        "title": "Error Rate",
        "type": "stat",
        "gridPos": {"h": 4, "w": 6, "x": 0, "y": 8},
        "targets": [{"expr": "sum(rate(http_responses_total{status=~\"5..\"}[5m])) / sum(rate(http_responses_total[5m]))"}]
      },
      {
        "title": "Active Sessions",
        "type": "stat",
        "gridPos": {"h": 4, "w": 6, "x": 6, "y": 8},
        "targets": [{"expr": "fusionagi_active_sessions"}]
      },
      {
        "title": "Head Analysis Duration",
        "type": "timeseries",
        "gridPos": {"h": 8, "w": 12, "x": 12, "y": 8},
        "targets": [{"expr": "histogram_quantile(0.95, rate(head_analysis_duration_seconds_bucket[5m]))", "legendFormat": "{{head}}"}]
      },
      {
        "title": "Consequence Engine Activity",
        "type": "timeseries",
        "gridPos": {"h": 8, "w": 12, "x": 0, "y": 16},
        "targets": [
          {"expr": "rate(consequence_choices_total[5m])", "legendFormat": "Choices"},
          {"expr": "rate(consequence_surprises_total[5m])", "legendFormat": "Surprises"}
        ]
      },
      {
        "title": "Cache Hit Rate",
        "type": "gauge",
        "gridPos": {"h": 4, "w": 6, "x": 12, "y": 16},
        "targets": [{"expr": "sum(rate(cache_hits_total[5m])) / (sum(rate(cache_hits_total[5m])) + sum(rate(cache_misses_total[5m])))"}]
      },
      {
        "title": "Connection Pool",
        "type": "stat",
        "gridPos": {"h": 4, "w": 6, "x": 18, "y": 16},
        "targets": [
          {"expr": "connection_pool_in_use", "legendFormat": "In Use"},
          {"expr": "connection_pool_available", "legendFormat": "Available"}
        ]
      }
    ],
    "templating": {
      "list": [
        {"name": "datasource", "type": "datasource", "query": "prometheus"},
        {"name": "instance", "type": "query", "query": "label_values(up{job=\"fusionagi\"}, instance)"}
      ]
    }
  }
}
tests/test_cache.py (new file, 64 lines)
@@ -0,0 +1,64 @@
"""Tests for response cache."""

import time

from fusionagi.api.cache import ResponseCache


def test_cache_set_and_get():
    cache = ResponseCache(max_size=10, ttl_seconds=60)
    cache.set("hello", "s1", {"answer": "world"})
    result = cache.get("hello", "s1")
    assert result == {"answer": "world"}


def test_cache_miss():
    cache = ResponseCache()
    assert cache.get("nonexistent", "s1") is None


def test_cache_ttl_expiry():
    cache = ResponseCache(ttl_seconds=0.01)
    cache.set("prompt", "s1", "cached")
    time.sleep(0.02)
    assert cache.get("prompt", "s1") is None


def test_cache_invalidate():
    cache = ResponseCache()
    cache.set("p", "s", "val")
    assert cache.invalidate("p", "s") is True
    assert cache.get("p", "s") is None


def test_cache_clear():
    cache = ResponseCache()
    cache.set("a", "s", 1)
    cache.set("b", "s", 2)
    count = cache.clear()
    assert count == 2
    assert cache.get("a", "s") is None


def test_cache_max_size():
    cache = ResponseCache(max_size=2)
    cache.set("a", "s", 1)
    cache.set("b", "s", 2)
    cache.set("c", "s", 3)
    assert cache.stats()["total"] == 2


def test_cache_stats():
    cache = ResponseCache(max_size=100)
    cache.set("a", "s", 1)
    stats = cache.stats()
    assert stats["total"] == 1
    assert stats["max_size"] == 100


def test_cache_tenant_isolation():
    cache = ResponseCache()
    cache.set("prompt", "s1", "tenant_a_result", tenant_id="a")
    cache.set("prompt", "s1", "tenant_b_result", tenant_id="b")
    assert cache.get("prompt", "s1", "a") == "tenant_a_result"
    assert cache.get("prompt", "s1", "b") == "tenant_b_result"
tests/test_config.py (new file, 30 lines)
@@ -0,0 +1,30 @@
"""Tests for environment-based configuration."""

from fusionagi.settings import FusionAGIConfig, load_config


def test_default_config():
    config = FusionAGIConfig()
    assert config.api.host == "0.0.0.0"
    assert config.api.port == 8000
    assert config.api.rate_limit == 120
    assert config.database.url == "sqlite:///fusionagi.db"
    assert config.cache.enabled is True
    assert config.governance.mode == "advisory"


def test_load_config_from_env(monkeypatch):
    monkeypatch.setenv("FUSIONAGI_API_PORT", "9000")
    monkeypatch.setenv("FUSIONAGI_LOG_LEVEL", "DEBUG")
    config = load_config()
    assert config.api.port == 9000
    assert config.logging.level == "DEBUG"


def test_config_sections():
    config = FusionAGIConfig()
    assert hasattr(config, "api")
    assert hasattr(config, "database")
    assert hasattr(config, "cache")
    assert hasattr(config, "logging")
    assert hasattr(config, "governance")
tests/test_connection_pool.py (new file, 65 lines)
@@ -0,0 +1,65 @@
"""Tests for connection pool."""

import pytest

from fusionagi.api.pool import ConnectionPool


class MockConnection:
    """Mock connection for testing."""

    def __init__(self):
        self.connected = False
        self.closed = False

    async def connect(self):
        self.connected = True

    async def close(self):
        self.closed = True

    def is_alive(self):
        return self.connected and not self.closed


@pytest.fixture
def pool():
    return ConnectionPool(factory=MockConnection, min_size=2, max_size=5)


@pytest.mark.asyncio
async def test_initialize(pool):
    await pool.initialize()
    stats = pool.stats()
    assert stats["available"] == 2
    assert stats["total_created"] == 2


@pytest.mark.asyncio
async def test_acquire_and_release(pool):
    await pool.initialize()
    conn = await pool.acquire()
    assert isinstance(conn, MockConnection)
    stats = pool.stats()
    assert stats["in_use"] == 1
    await pool.release(conn)
    stats = pool.stats()
    assert stats["in_use"] == 0


@pytest.mark.asyncio
async def test_close_all(pool):
    await pool.initialize()
    await pool.close_all()
    stats = pool.stats()
    assert stats["available"] == 0


@pytest.mark.asyncio
async def test_max_size():
    pool = ConnectionPool(factory=MockConnection, min_size=1, max_size=2)
    await pool.initialize()
    c1 = await pool.acquire()
    c2 = await pool.acquire()
    assert pool.stats()["in_use"] == 2
    await pool.release(c1)
    await pool.release(c2)
tests/test_migration.py (new file, 47 lines)
@@ -0,0 +1,47 @@
"""Tests for migration system."""

import os
import sqlite3
import tempfile

from migrations.migrate import migrate_down, migrate_up


def test_migrate_up():
    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
        db_path = f.name
    try:
        count = migrate_up(db_path)
        assert count >= 1
        conn = sqlite3.connect(db_path)
        tables = conn.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()
        table_names = [t[0] for t in tables]
        assert "sessions" in table_names
        assert "ethical_lessons" in table_names
        assert "consequences" in table_names
        conn.close()
    finally:
        os.unlink(db_path)


def test_migrate_down():
    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
        db_path = f.name
    try:
        migrate_up(db_path)
        result = migrate_down(db_path)
        assert result is True
    finally:
        os.unlink(db_path)


def test_migrate_idempotent():
    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
        db_path = f.name
    try:
        count1 = migrate_up(db_path)
        count2 = migrate_up(db_path)
        assert count1 >= 1
        assert count2 == 0
    finally:
        os.unlink(db_path)
tests/test_secret_rotation.py (new file, 65 lines)
@@ -0,0 +1,65 @@
"""Tests for secret rotation mechanism."""

import time

from fusionagi.api.secret_rotation import SecretRotator


def test_generate_and_validate():
    rotator = SecretRotator()
    key = rotator.generate_key()
    assert rotator.validate_key(key) is True


def test_invalid_key():
    rotator = SecretRotator()
    assert rotator.validate_key("invalid") is False


def test_key_expiry():
    rotator = SecretRotator()
    key = rotator.generate_key(ttl_seconds=0.01)
    assert rotator.validate_key(key) is True
    time.sleep(0.02)
    assert rotator.validate_key(key) is False


def test_revoke():
    rotator = SecretRotator()
    key = rotator.generate_key()
    assert rotator.revoke(key) is True
    assert rotator.validate_key(key) is False


def test_rotate():
    rotator = SecretRotator()
    key1 = rotator.generate_key()
    key2 = rotator.rotate()
    assert rotator.validate_key(key1) is True
    assert rotator.validate_key(key2) is True


def test_max_active_keys():
    rotator = SecretRotator(max_active_keys=2)
    key1 = rotator.generate_key()
    rotator.generate_key()
    rotator.generate_key()
    assert rotator.validate_key(key1) is False


def test_list_keys():
    rotator = SecretRotator()
    rotator.generate_key(label="test")
    keys = rotator.list_keys()
    assert len(keys) == 1
    assert keys[0]["label"] == "test"
    assert "key_hash" not in keys[0]


def test_revoke_expired():
    rotator = SecretRotator()
    rotator.generate_key(ttl_seconds=0.01)
    rotator.generate_key(ttl_seconds=100)
    time.sleep(0.02)
    count = rotator.revoke_expired()
    assert count == 1
tests/test_task_queue.py (new file, 68 lines)
@@ -0,0 +1,68 @@
"""Tests for background task queue."""

import asyncio

import pytest

from fusionagi.api.task_queue import BackgroundTaskQueue, TaskStatus


@pytest.fixture
def queue():
    return BackgroundTaskQueue(max_concurrent=3)


@pytest.mark.asyncio
async def test_submit_and_complete(queue):
    async def work():
        await asyncio.sleep(0.01)
        return 42

    tid = queue.submit(work)
    await asyncio.sleep(0.05)
    result = queue.get_status(tid)
    assert result is not None
    assert result.status == TaskStatus.COMPLETED
    assert result.result == 42


@pytest.mark.asyncio
async def test_failed_task(queue):
    async def fail():
        raise ValueError("boom")

    tid = queue.submit(fail)
    await asyncio.sleep(0.05)
    result = queue.get_status(tid)
    assert result is not None
    assert result.status == TaskStatus.FAILED
    assert "boom" in (result.error or "")


@pytest.mark.asyncio
async def test_list_tasks(queue):
    async def noop():
        pass

    queue.submit(noop)
    queue.submit(noop)
    await asyncio.sleep(0.05)
    tasks = queue.list_tasks()
    assert len(tasks) == 2


@pytest.mark.asyncio
async def test_list_tasks_filtered(queue):
    async def noop():
        pass

    queue.submit(noop)
    await asyncio.sleep(0.05)
    completed = queue.list_tasks(status=TaskStatus.COMPLETED)
    assert len(completed) == 1
    pending = queue.list_tasks(status=TaskStatus.PENDING)
    assert len(pending) == 0


def test_nonexistent_task(queue):
    assert queue.get_status("nonexistent") is None
tests/test_tracing.py (new file, 19 lines)
@@ -0,0 +1,19 @@
"""Tests for request tracing."""

from fusionagi.api.tracing import generate_trace_id, get_trace_id, set_trace_id


def test_generate_trace_id():
    tid = generate_trace_id()
    assert len(tid) == 8
    assert isinstance(tid, str)


def test_set_and_get_trace_id():
    set_trace_id("abc123")
    assert get_trace_id() == "abc123"


def test_default_trace_id():
    set_trace_id("")
    assert get_trace_id() == ""
tests/test_vector_memory.py (new file, 56 lines)
@@ -0,0 +1,56 @@
"""Tests for vector memory with cosine similarity."""

from fusionagi.memory.service import VectorMemory


def test_add_and_search():
    vm = VectorMemory()
    vm.add("doc1", [1.0, 0.0, 0.0], {"text": "hello"})
    vm.add("doc2", [0.0, 1.0, 0.0], {"text": "world"})
    results = vm.search([1.0, 0.0, 0.0], top_k=1)
    assert len(results) == 1
    assert results[0]["id"] == "doc1"
    assert results[0]["score"] > 0.99


def test_cosine_similarity():
    assert abs(VectorMemory._cosine_similarity([1, 0], [1, 0]) - 1.0) < 0.001
    assert abs(VectorMemory._cosine_similarity([1, 0], [0, 1])) < 0.001
    assert abs(VectorMemory._cosine_similarity([1, 1], [1, 1]) - 1.0) < 0.001


def test_zero_vector():
    assert VectorMemory._cosine_similarity([0, 0], [1, 0]) == 0.0


def test_delete():
    vm = VectorMemory()
    vm.add("doc1", [1.0, 0.0])
    assert vm.count() == 1
    assert vm.delete("doc1") is True
    assert vm.count() == 0


def test_max_entries():
    vm = VectorMemory(max_entries=2)
    vm.add("a", [1.0])
    vm.add("b", [2.0])
    vm.add("c", [3.0])
    assert vm.count() == 2


def test_search_top_k():
    vm = VectorMemory()
    vm.add("a", [1.0, 0.0])
    vm.add("b", [0.9, 0.1])
    vm.add("c", [0.0, 1.0])
    results = vm.search([1.0, 0.0], top_k=2)
    assert len(results) == 2
    assert results[0]["id"] == "a"


def test_search_with_metadata():
    vm = VectorMemory()
    vm.add("doc", [1.0], {"key": "value"})
    results = vm.search([1.0])
    assert results[0]["metadata"]["key"] == "value"