import { useState, useCallback, useRef, useEffect } from 'react'
import type { WSEvent } from '../types'

type WSStatus = 'disconnected' | 'connecting' | 'connected' | 'error'

/** Maximum automatic reconnect attempts; once exhausted, `sendPrompt` falls back to SSE. */
const MAX_RETRIES = 10
/** Base reconnect delay (ms); attempt k waits BASE_DELAY * 2^k plus up to 500 ms of jitter. */
const BASE_DELAY = 1000

/** Listeners for one streamed prompt, shared by the WebSocket and SSE transports. */
export interface StreamCallbacks {
  /** Receives each streamed token. */
  onToken?: (token: string) => void
  /** Receives `head_update` events as (head, content). */
  onHeadUpdate?: (head: string, content: string) => void
  /** Receives the parsed payload of the final `complete` event. */
  onComplete?: (response: Record<string, unknown>) => void
  /** Receives a message when the stream fails or the prompt cannot be sent. */
  onError?: (error: string) => void
}

/** Shape of a `head_update` payload as read by this hook. */
interface HeadUpdatePayload {
  head: string
  content: string
}

/** Parses JSON text, returning `undefined` instead of throwing on malformed input. */
function tryParseJSON(raw: string): unknown {
  try {
    return JSON.parse(raw)
  } catch {
    return undefined
  }
}

/**
 * Manages a per-session streaming connection: a WebSocket with exponential-backoff
 * reconnects, plus a Server-Sent Events fallback once reconnects are exhausted.
 *
 * @param sessionId - Session used by the SSE fallback inside `sendPrompt`; the
 *   WebSocket itself is opened explicitly with `connect(sid)`.
 * @returns `status`, the log of received `events`, the `streaming` flag, and the
 *   control functions `connect`, `send`, `sendPrompt` (WS with SSE fallback),
 *   `sendPromptSSE`, `disconnect` and `clearEvents`.
 */
export function useWebSocket(sessionId: string | null) {
  const [status, setStatus] = useState<WSStatus>('disconnected')
  const [events, setEvents] = useState<WSEvent[]>([])
  const [streaming, setStreaming] = useState(false)

  const wsRef = useRef<WebSocket | null>(null)
  const sseRef = useRef<EventSource | null>(null)
  const retryCount = useRef(0)
  const retryTimer = useRef<ReturnType<typeof setTimeout> | null>(null)
  const shouldReconnect = useRef(true)
  const callbacksRef = useRef<StreamCallbacks>({})
  // True while a prompt sent over the WebSocket is waiting for `complete`/`error`.
  const wsPromptPending = useRef(false)

  /** Cancels a scheduled reconnect, if any. */
  const clearRetryTimer = useCallback(() => {
    if (retryTimer.current !== null) {
      clearTimeout(retryTimer.current)
      retryTimer.current = null
    }
  }, [])

  /** Closes the active SSE stream, if any. */
  const closeSSE = useCallback(() => {
    sseRef.current?.close()
    sseRef.current = null
  }, [])

  /** Ends an in-flight WebSocket prompt whose socket went away, notifying the caller. */
  const abortPendingPrompt = useCallback((reason: string) => {
    if (!wsPromptPending.current) return
    wsPromptPending.current = false
    setStreaming(false)
    callbacksRef.current.onError?.(reason)
  }, [])

  const connect = useCallback((sid: string) => {
    if (wsRef.current?.readyState === WebSocket.OPEN) return
    // A manual connect supersedes any scheduled retry; otherwise the timer
    // could later open a second socket.
    clearRetryTimer()

    const previous = wsRef.current
    // Detach before closing: every handler below ignores events from a socket
    // that is no longer wsRef.current, so the old socket's close event cannot
    // clobber status or schedule a reconnect for a stale session.
    wsRef.current = null
    previous?.close()
    abortPendingPrompt('WebSocket connection closed')

    shouldReconnect.current = true
    setStatus('connecting')
    const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:'
    const ws = new WebSocket(`${protocol}//${window.location.host}/v1/sessions/${sid}/stream`)
    wsRef.current = ws

    ws.onopen = () => {
      if (wsRef.current !== ws) return
      setStatus('connected')
      retryCount.current = 0
    }

    ws.onerror = () => {
      // Stream cleanup and reconnects are handled in onclose, which follows a
      // failed connection.
      if (wsRef.current === ws) setStatus('error')
    }

    ws.onclose = () => {
      if (wsRef.current !== ws) return
      // Keep an 'error' status visible rather than overwriting it immediately.
      setStatus((prev) => (prev === 'error' ? prev : 'disconnected'))
      abortPendingPrompt('WebSocket connection closed')
      if (shouldReconnect.current && retryCount.current < MAX_RETRIES) {
        const delay = BASE_DELAY * 2 ** retryCount.current + Math.random() * 500
        retryCount.current++
        retryTimer.current = setTimeout(() => connect(sid), delay)
      }
    }

    ws.onmessage = (e: MessageEvent) => {
      if (wsRef.current !== ws) return
      const event = tryParseJSON(e.data) as WSEvent | undefined
      if (event == null) return // ignore malformed frames
      setEvents((prev) => [...prev, event])

      // Streaming protocol: terminal events always clear `streaming`, whether
      // or not the caller registered the matching callback.
      const cb = callbacksRef.current
      switch (event.type) {
        case 'token':
          cb.onToken?.(event.data as string)
          break
        case 'head_update': {
          const d = event.data as HeadUpdatePayload
          cb.onHeadUpdate?.(d.head, d.content)
          break
        }
        case 'complete':
          wsPromptPending.current = false
          setStreaming(false)
          cb.onComplete?.(event.data as Record<string, unknown>)
          break
        case 'error':
          wsPromptPending.current = false
          setStreaming(false)
          cb.onError?.(event.data as string)
          break
      }
    }
  }, [clearRetryTimer, abortPendingPrompt])

  /**
   * Sends a JSON message over the WebSocket.
   * @returns `true` if the socket was open and the message was sent.
   */
  const send = useCallback((data: Record<string, unknown>): boolean => {
    const ws = wsRef.current
    if (!ws || ws.readyState !== WebSocket.OPEN) return false
    ws.send(JSON.stringify(data))
    return true
  }, [])

  /** Sends a prompt over the WebSocket; reports `onError` if the socket is not open. */
  const sendPrompt = useCallback((prompt: string, callbacks?: StreamCallbacks) => {
    if (callbacks) callbacksRef.current = callbacks
    if (!send({ type: 'prompt', prompt })) {
      // Without an open socket the prompt would be silently dropped and
      // `streaming` would never clear, so surface the failure instead.
      callbacksRef.current.onError?.('WebSocket is not connected')
      return
    }
    wsPromptPending.current = true
    setStreaming(true)
  }, [send])

  /** Closes the WebSocket and any SSE stream, and stops reconnecting. */
  const disconnect = useCallback(() => {
    shouldReconnect.current = false
    clearRetryTimer()
    const ws = wsRef.current
    wsRef.current = null // detach first so this socket's close event is ignored
    ws?.close()
    closeSSE()
    wsPromptPending.current = false
    setStatus('disconnected')
    setStreaming(false)
    retryCount.current = 0
  }, [clearRetryTimer, closeSSE])

  const clearEvents = useCallback(() => setEvents([]), [])

  /** SSE fallback: streams a prompt over Server-Sent Events (one stream at a time). */
  const sendPromptSSE = useCallback((sessionId: string, prompt: string, callbacks?: StreamCallbacks) => {
    if (callbacks) callbacksRef.current = callbacks
    const cb = callbacksRef.current
    closeSSE() // a new prompt replaces any previous SSE stream
    setStreaming(true)

    const params = new URLSearchParams({ prompt, session_id: sessionId })
    let source: EventSource
    try {
      source = new EventSource(`/v1/sessions/stream/sse?${params}`)
    } catch {
      setStreaming(false)
      cb.onError?.('SSE connection failed')
      return
    }
    sseRef.current = source

    /** Closes this stream (EventSource would otherwise auto-reconnect) and clears `streaming`. */
    const finish = () => {
      source.close()
      if (sseRef.current === source) sseRef.current = null
      setStreaming(false)
    }

    source.addEventListener('token', (e) => {
      cb.onToken?.(e.data)
    })

    source.addEventListener('head_update', (e) => {
      const data = tryParseJSON(e.data) as HeadUpdatePayload | undefined
      if (data) cb.onHeadUpdate?.(data.head, data.content)
    })

    source.addEventListener('complete', (e) => {
      const data = tryParseJSON(e.data)
      finish()
      if (data === undefined) {
        cb.onError?.('SSE complete event had a malformed payload')
      } else {
        cb.onComplete?.(data as Record<string, unknown>)
      }
    })

    // Single handler for both cases: server-sent `event: error` frames arrive as
    // MessageEvents carrying a payload; a plain Event signals a connection failure.
    source.addEventListener('error', (e) => {
      finish()
      cb.onError?.(e instanceof MessageEvent ? e.data : 'SSE connection failed')
    })
  }, [closeSSE])

  // Auto-fallback: after MAX_RETRIES WS failures, switch to SSE.
  const sendWithFallback = useCallback((prompt: string, callbacks?: StreamCallbacks) => {
    if (wsRef.current?.readyState === WebSocket.OPEN) {
      sendPrompt(prompt, callbacks)
    } else if (sessionId && retryCount.current >= MAX_RETRIES) {
      sendPromptSSE(sessionId, prompt, callbacks)
    } else {
      sendPrompt(prompt, callbacks)
    }
  }, [sendPrompt, sendPromptSSE, sessionId])

  // On unmount: stop reconnecting and close both transports. Detaching wsRef
  // first keeps the socket's close handler from updating unmounted state.
  useEffect(() => {
    return () => {
      shouldReconnect.current = false
      clearRetryTimer()
      const ws = wsRef.current
      wsRef.current = null
      ws?.close()
      closeSSE()
    }
  }, [clearRetryTimer, closeSSE])

  return { status, events, streaming, connect, send, sendPrompt: sendWithFallback, sendPromptSSE, disconnect, clearEvents }
}