Initial commit: add .gitignore and README
This commit is contained in:
338
fusionagi/interfaces/voice.py
Normal file
338
fusionagi/interfaces/voice.py
Normal file
@@ -0,0 +1,338 @@
|
||||
"""Voice interface: speech-to-text, text-to-speech, voice library management."""
|
||||
|
||||
import uuid
|
||||
from typing import Any, Literal, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from fusionagi._time import utc_now_iso
|
||||
from fusionagi.interfaces.base import InterfaceAdapter, InterfaceCapabilities, InterfaceMessage, ModalityType
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class TTSAdapter(Protocol):
|
||||
"""Protocol for TTS providers (ElevenLabs, Azure, system, etc.). Integrate by injecting an implementation."""
|
||||
|
||||
async def synthesize(self, text: str, voice_id: str | None = None, **kwargs: Any) -> bytes | None:
|
||||
"""Synthesize text to audio. Returns raw audio bytes or None if not available."""
|
||||
...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class STTAdapter(Protocol):
|
||||
"""Protocol for STT providers (Whisper, Azure, Google, etc.). Integrate by injecting an implementation."""
|
||||
|
||||
async def transcribe(self, audio_data: bytes | None = None, timeout_seconds: float | None = None, **kwargs: Any) -> str | None:
|
||||
"""Transcribe audio to text. Returns transcribed text or None if timeout/unavailable."""
|
||||
...
|
||||
|
||||
|
||||
class VoiceProfile(BaseModel):
|
||||
"""Voice profile for text-to-speech synthesis."""
|
||||
|
||||
id: str = Field(default_factory=lambda: f"voice_{uuid.uuid4().hex[:8]}")
|
||||
name: str = Field(description="Human-readable voice name")
|
||||
language: str = Field(default="en-US", description="Language code (e.g., en-US, es-ES)")
|
||||
gender: Literal["male", "female", "neutral"] | None = Field(default=None)
|
||||
age_range: Literal["child", "young_adult", "adult", "senior"] | None = Field(default=None)
|
||||
style: str | None = Field(default=None, description="Voice style (e.g., friendly, professional, calm)")
|
||||
pitch: float = Field(default=1.0, ge=0.5, le=2.0, description="Pitch multiplier")
|
||||
speed: float = Field(default=1.0, ge=0.5, le=2.0, description="Speed multiplier")
|
||||
provider: str = Field(default="system", description="TTS provider (e.g., system, elevenlabs, azure)")
|
||||
provider_voice_id: str | None = Field(default=None, description="Provider-specific voice ID")
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
created_at: str = Field(default_factory=utc_now_iso)
|
||||
|
||||
|
||||
class VoiceLibrary:
|
||||
"""
|
||||
Voice library for managing TTS voice profiles.
|
||||
|
||||
Allows admin to add, configure, and organize voice profiles for different
|
||||
agents, contexts, or user preferences.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._voices: dict[str, VoiceProfile] = {}
|
||||
self._default_voice_id: str | None = None
|
||||
logger.info("VoiceLibrary initialized")
|
||||
|
||||
def add_voice(self, profile: VoiceProfile) -> str:
|
||||
"""
|
||||
Add a voice profile to the library.
|
||||
|
||||
Args:
|
||||
profile: Voice profile to add.
|
||||
|
||||
Returns:
|
||||
Voice ID.
|
||||
"""
|
||||
self._voices[profile.id] = profile
|
||||
if self._default_voice_id is None:
|
||||
self._default_voice_id = profile.id
|
||||
logger.info("Voice added", extra={"voice_id": profile.id, "name": profile.name})
|
||||
return profile.id
|
||||
|
||||
def remove_voice(self, voice_id: str) -> bool:
|
||||
"""
|
||||
Remove a voice profile from the library.
|
||||
|
||||
Args:
|
||||
voice_id: ID of voice to remove.
|
||||
|
||||
Returns:
|
||||
True if removed, False if not found.
|
||||
"""
|
||||
if voice_id in self._voices:
|
||||
del self._voices[voice_id]
|
||||
if self._default_voice_id == voice_id:
|
||||
self._default_voice_id = next(iter(self._voices.keys()), None)
|
||||
logger.info("Voice removed", extra={"voice_id": voice_id})
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_voice(self, voice_id: str) -> VoiceProfile | None:
|
||||
"""Get a voice profile by ID."""
|
||||
return self._voices.get(voice_id)
|
||||
|
||||
def list_voices(
|
||||
self,
|
||||
language: str | None = None,
|
||||
gender: str | None = None,
|
||||
style: str | None = None,
|
||||
) -> list[VoiceProfile]:
|
||||
"""
|
||||
List voice profiles with optional filtering.
|
||||
|
||||
Args:
|
||||
language: Filter by language code.
|
||||
gender: Filter by gender.
|
||||
style: Filter by style.
|
||||
|
||||
Returns:
|
||||
List of matching voice profiles.
|
||||
"""
|
||||
voices = list(self._voices.values())
|
||||
|
||||
if language:
|
||||
voices = [v for v in voices if v.language == language]
|
||||
if gender:
|
||||
voices = [v for v in voices if v.gender == gender]
|
||||
if style:
|
||||
voices = [v for v in voices if v.style == style]
|
||||
|
||||
return voices
|
||||
|
||||
def set_default_voice(self, voice_id: str) -> bool:
|
||||
"""
|
||||
Set the default voice for the library.
|
||||
|
||||
Args:
|
||||
voice_id: ID of voice to set as default.
|
||||
|
||||
Returns:
|
||||
True if set, False if voice not found.
|
||||
"""
|
||||
if voice_id in self._voices:
|
||||
self._default_voice_id = voice_id
|
||||
logger.info("Default voice set", extra={"voice_id": voice_id})
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_default_voice(self) -> VoiceProfile | None:
|
||||
"""Get the default voice profile."""
|
||||
if self._default_voice_id:
|
||||
return self._voices.get(self._default_voice_id)
|
||||
return None
|
||||
|
||||
def update_voice(self, voice_id: str, updates: dict[str, Any]) -> bool:
|
||||
"""
|
||||
Update a voice profile.
|
||||
|
||||
Args:
|
||||
voice_id: ID of voice to update.
|
||||
updates: Dictionary of fields to update.
|
||||
|
||||
Returns:
|
||||
True if updated, False if not found.
|
||||
"""
|
||||
if voice_id not in self._voices:
|
||||
return False
|
||||
|
||||
voice = self._voices[voice_id]
|
||||
for key, value in updates.items():
|
||||
if hasattr(voice, key):
|
||||
setattr(voice, key, value)
|
||||
|
||||
logger.info("Voice updated", extra={"voice_id": voice_id, "updates": list(updates.keys())})
|
||||
return True
|
||||
|
||||
|
||||
class VoiceInterface(InterfaceAdapter):
|
||||
"""
|
||||
Voice interface adapter for speech interaction.
|
||||
|
||||
Handles:
|
||||
- Speech-to-text (STT) for user input
|
||||
- Text-to-speech (TTS) for system output
|
||||
- Voice activity detection
|
||||
- Noise cancellation
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str = "voice",
|
||||
voice_library: VoiceLibrary | None = None,
|
||||
stt_provider: str = "whisper",
|
||||
tts_provider: str = "system",
|
||||
tts_adapter: TTSAdapter | None = None,
|
||||
stt_adapter: STTAdapter | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize voice interface.
|
||||
|
||||
Args:
|
||||
name: Interface name.
|
||||
voice_library: Voice library for TTS profiles.
|
||||
stt_provider: Speech-to-text provider (whisper, azure, google, etc.).
|
||||
tts_provider: Text-to-speech provider (system, elevenlabs, azure, etc.).
|
||||
tts_adapter: Optional TTS adapter for synthesis (inject to integrate ElevenLabs, Azure, etc.).
|
||||
stt_adapter: Optional STT adapter for transcription (inject to integrate Whisper, Azure, etc.).
|
||||
"""
|
||||
super().__init__(name)
|
||||
self.voice_library = voice_library or VoiceLibrary()
|
||||
self.stt_provider = stt_provider
|
||||
self.tts_provider = tts_provider
|
||||
self._tts_adapter = tts_adapter
|
||||
self._stt_adapter = stt_adapter
|
||||
self._active_voice_id: str | None = None
|
||||
logger.info(
|
||||
"VoiceInterface initialized",
|
||||
extra={"stt_provider": stt_provider, "tts_provider": tts_provider}
|
||||
)
|
||||
|
||||
def capabilities(self) -> InterfaceCapabilities:
|
||||
"""Return voice interface capabilities."""
|
||||
return InterfaceCapabilities(
|
||||
supported_modalities=[ModalityType.VOICE],
|
||||
supports_streaming=True,
|
||||
supports_interruption=True,
|
||||
supports_multimodal=False,
|
||||
latency_ms=200.0, # Typical voice latency
|
||||
max_concurrent_sessions=10,
|
||||
)
|
||||
|
||||
async def send(self, message: InterfaceMessage) -> None:
|
||||
"""
|
||||
Send voice output (text-to-speech).
|
||||
|
||||
Args:
|
||||
message: Message with text content to synthesize.
|
||||
"""
|
||||
if not self.validate_message(message):
|
||||
logger.warning("Invalid message for voice interface", extra={"modality": message.modality})
|
||||
return
|
||||
|
||||
# Get voice profile
|
||||
voice_id = message.metadata.get("voice_id", self._active_voice_id)
|
||||
voice = None
|
||||
if voice_id:
|
||||
voice = self.voice_library.get_voice(voice_id)
|
||||
if not voice:
|
||||
voice = self.voice_library.get_default_voice()
|
||||
|
||||
text = message.content if isinstance(message.content, str) else str(message.content)
|
||||
voice_id = voice.id if voice else None
|
||||
if self._tts_adapter is not None:
|
||||
try:
|
||||
audio_data = await self._tts_adapter.synthesize(text, voice_id=voice_id)
|
||||
if audio_data:
|
||||
logger.info(
|
||||
"TTS synthesis (adapter)",
|
||||
extra={"text_length": len(text), "voice_id": voice_id, "bytes": len(audio_data)},
|
||||
)
|
||||
# Inject: await self._play_audio(audio_data)
|
||||
except Exception as e:
|
||||
logger.exception("TTS adapter failed", extra={"error": str(e)})
|
||||
else:
|
||||
logger.info(
|
||||
"TTS synthesis (stub; inject tts_adapter for ElevenLabs, Azure, etc.)",
|
||||
extra={"text_length": len(text), "voice_id": voice_id, "provider": self.tts_provider},
|
||||
)
|
||||
|
||||
async def receive(self, timeout_seconds: float | None = None) -> InterfaceMessage | None:
|
||||
"""
|
||||
Receive voice input (speech-to-text).
|
||||
|
||||
Args:
|
||||
timeout_seconds: Optional timeout for listening.
|
||||
|
||||
Returns:
|
||||
Message with transcribed text or None if timeout.
|
||||
"""
|
||||
logger.info("STT listening", extra={"timeout": timeout_seconds, "provider": self.stt_provider})
|
||||
if self._stt_adapter is not None:
|
||||
try:
|
||||
text = await self._stt_adapter.transcribe(audio_data=None, timeout_seconds=timeout_seconds)
|
||||
if text:
|
||||
return InterfaceMessage(
|
||||
id=f"stt_{uuid.uuid4().hex[:8]}",
|
||||
modality=ModalityType.VOICE,
|
||||
content=text,
|
||||
metadata={"provider": self.stt_provider},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("STT adapter failed", extra={"error": str(e)})
|
||||
return None
|
||||
|
||||
def set_active_voice(self, voice_id: str) -> bool:
|
||||
"""
|
||||
Set the active voice for this interface session.
|
||||
|
||||
Args:
|
||||
voice_id: ID of voice to use.
|
||||
|
||||
Returns:
|
||||
True if voice exists, False otherwise.
|
||||
"""
|
||||
if self.voice_library.get_voice(voice_id):
|
||||
self._active_voice_id = voice_id
|
||||
logger.info("Active voice set", extra={"voice_id": voice_id})
|
||||
return True
|
||||
return False
|
||||
|
||||
async def _synthesize_speech(self, text: str, voice: VoiceProfile | None) -> bytes:
|
||||
"""
|
||||
Synthesize speech from text (to be implemented with actual provider).
|
||||
|
||||
Args:
|
||||
text: Text to synthesize.
|
||||
voice: Voice profile to use.
|
||||
|
||||
Returns:
|
||||
Audio data as bytes.
|
||||
"""
|
||||
# Integrate with TTS provider based on self.tts_provider
|
||||
# - system: Use OS TTS (pyttsx3, etc.)
|
||||
# - elevenlabs: Use ElevenLabs API
|
||||
# - azure: Use Azure Cognitive Services
|
||||
# - google: Use Google Cloud TTS
|
||||
raise NotImplementedError("TTS provider integration required")
|
||||
|
||||
async def _transcribe_speech(self, audio_data: bytes) -> str:
|
||||
"""
|
||||
Transcribe speech to text (to be implemented with actual provider).
|
||||
|
||||
Args:
|
||||
audio_data: Audio data to transcribe.
|
||||
|
||||
Returns:
|
||||
Transcribed text.
|
||||
"""
|
||||
# Integrate with STT provider based on self.stt_provider
|
||||
# - whisper: Use OpenAI Whisper (local or API)
|
||||
# - azure: Use Azure Cognitive Services
|
||||
# - google: Use Google Cloud Speech-to-Text
|
||||
# - deepgram: Use Deepgram API
|
||||
raise NotImplementedError("STT provider integration required")
|
||||
Reference in New Issue
Block a user