diff --git a/examples/audio_moderation/main.py b/examples/audio_moderation/main.py
index e4141081..bd6681ad 100644
--- a/examples/audio_moderation/main.py
+++ b/examples/audio_moderation/main.py
@@ -141,13 +141,17 @@ async def main(client: Stream):
     # Set up transcription handler
     @connection.on("audio")
     async def on_audio(pcm: PcmData, user):
-        # Process audio through Deepgram STT
-        await stt.process_audio(pcm, user)
+        # Process audio through Deepgram STT with user metadata
+        user_metadata = {"user": user} if user else None
+        await stt.process_audio(pcm, user_metadata)

     @stt.on("transcript")
     async def on_transcript(event):
         timestamp = time.strftime("%H:%M:%S")
-        user_info = user.name if user and hasattr(user, "name") else "unknown"
+        user_info = "unknown"
+        if event.user_metadata and "user" in event.user_metadata:
+            user = event.user_metadata["user"]
+            user_info = user.name if hasattr(user, "name") else str(user)
         print(f"[{timestamp}] {user_info}: {event.text}")
         if hasattr(event, 'confidence') and event.confidence:
             print(f"  └─ confidence: {event.confidence:.2%}")
diff --git a/examples/fal_stt_translate/main.py b/examples/fal_stt_translate/main.py
index 15f99845..35cf4a71 100644
--- a/examples/fal_stt_translate/main.py
+++ b/examples/fal_stt_translate/main.py
@@ -134,13 +134,17 @@ async def on_speech_detected(pcm: PcmData, user):
         print(
             f"{time.time()} Speech detected from user: {user} duration {pcm.duration}"
         )
-        # Process audio through FAL.ai STT
-        await stt.process_audio(pcm, user)
+        # Process audio through FAL.ai STT with user metadata
+        user_metadata = {"user": user} if user else None
+        await stt.process_audio(pcm, user_metadata)

     @stt.on("transcript")
     async def on_transcript(event):
         timestamp = time.strftime("%H:%M:%S")
-        user_info = user.name if user and hasattr(user, "name") else "unknown"
+        user_info = "unknown"
+        if event.user_metadata and "user" in event.user_metadata:
+            user = event.user_metadata["user"]
+            user_info = user.name if hasattr(user, "name") else str(user)
         print(f"[{timestamp}] {user_info}: {event.text}")
         if hasattr(event, 'confidence') and event.confidence:
             print(f"  └─ confidence: {event.confidence:.2%}")
diff --git a/examples/llm_audio_conversation/main.py b/examples/llm_audio_conversation/main.py
index a5d07d78..fd56e60c 100644
--- a/examples/llm_audio_conversation/main.py
+++ b/examples/llm_audio_conversation/main.py
@@ -21,7 +21,6 @@
 import asyncio
 import os
 import time
-from typing import Any
 from uuid import uuid4
 import webbrowser
 from urllib.parse import urlencode
@@ -147,12 +146,18 @@ async def on_speech_detected(pcm: PcmData, user):
         print(
             f"{time.time()} Speech detected from user: {user} duration {pcm.duration}"
         )
-        await stt.process_audio(pcm, user)
+        # Process audio through STT with user metadata
+        user_metadata = {"user": user} if user else None
+        await stt.process_audio(pcm, user_metadata)

     @stt.on("transcript")
     async def on_transcript(event):
+        user_info = "unknown"
+        if event.user_metadata and "user" in event.user_metadata:
+            user = event.user_metadata["user"]
+            user_info = str(user)
         print(
-            f"{time.time()} got text from user {user}, with metadata {event.to_dict()}"
+            f"{time.time()} got text from user {user_info}, with metadata {event.to_dict()}"
             f"will send the transcript to the LLM: {event.text}"
         )

diff --git a/examples/mcp/main.py b/examples/mcp/main.py
index 662ced37..423dae25 100644
--- a/examples/mcp/main.py
+++ b/examples/mcp/main.py
@@ -82,11 +82,17 @@ async def run_bot(call: Call, bot_user_id: str):
     @connection.on("audio")
     async def on_audio(pcm: PcmData, user):
         """Pipe raw PCM into the STT engine."""
-        await stt.process_audio(pcm, user)
+        # Process audio through STT with user metadata
+        user_metadata = {"user": user} if user else None
+        await stt.process_audio(pcm, user_metadata)

     @stt.on("transcript")
     async def on_transcript(event):
-        logging.info("🗣️ %s: %s", user or "unknown", event.text)
+        user_info = "unknown"
+        if event.user_metadata and "user" in event.user_metadata:
+            user = event.user_metadata["user"]
+            user_info = str(user)
+        logging.info("🗣️ %s: %s", user_info, event.text)

         # Ask the LLM; it may decide to call an MCP tool.
         answer = await chat_with_tools(event.text, mcp_client)
diff --git a/examples/stt_deepgram_transcription/main.py b/examples/stt_deepgram_transcription/main.py
index 8dd35b19..7ff0bf5b 100644
--- a/examples/stt_deepgram_transcription/main.py
+++ b/examples/stt_deepgram_transcription/main.py
@@ -22,7 +22,6 @@
 import uuid
 import webbrowser
 from urllib.parse import urlencode
-from typing import Any

 from dotenv import load_dotenv

@@ -125,13 +124,18 @@ async def main():
     # Set up transcription handlers
     @connection.on("audio")
     async def on_audio(pcm: PcmData, user):
-        # Process audio through Deepgram STT
-        await stt.process_audio(pcm, user)
+        # Process audio through Deepgram STT with user metadata
+        user_metadata = {"user": user} if user else None
+        await stt.process_audio(pcm, user_metadata)

     @stt.on("transcript")
     async def on_transcript(event):
         timestamp = time.strftime("%H:%M:%S")
-        print(f"[{timestamp}] {event.user_metadata.name}: {event.text}")
+        user_info = "unknown"
+        if event.user_metadata and "user" in event.user_metadata:
+            user = event.user_metadata["user"]
+            user_info = user.name if hasattr(user, "name") else str(user)
+        print(f"[{timestamp}] {user_info}: {event.text}")
         if hasattr(event, 'confidence') and event.confidence:
             print(f"  └─ confidence: {event.confidence:.2%}")
         if hasattr(event, 'processing_time_ms') and event.processing_time_ms:
@@ -140,9 +144,10 @@ async def on_transcript(event):
     @stt.on("partial_transcript")
     async def on_partial_transcript(event):
         if event.text.strip():  # Only show non-empty partial transcripts
-            user_info = (
-                user.name if user and hasattr(user, "name") else "unknown"
-            )
+            user_info = "unknown"
+            if event.user_metadata and "user" in event.user_metadata:
+                user = event.user_metadata["user"]
+                user_info = user.name if hasattr(user, "name") else str(user)
             print(
                 f"  {user_info} (partial): {event.text}", end="\r"
             )  # Overwrite line
diff --git a/examples/stt_moonshine_transcription/main.py b/examples/stt_moonshine_transcription/main.py
index 3f6f704a..f5b8aa2b 100644
--- a/examples/stt_moonshine_transcription/main.py
+++ b/examples/stt_moonshine_transcription/main.py
@@ -128,13 +128,18 @@ async def _on_speech_detected(pcm: PcmData, user):
         print(
             f"🎤 Speech detected from user: {user.name}, duration: {pcm.duration:.2f}s"
         )
-        await stt.process_audio(pcm, user)
+        # Process audio through STT with user metadata
+        user_metadata = {"user": user} if user else None
+        await stt.process_audio(pcm, user_metadata)

     @stt.on("transcript")
     async def _on_transcript(event):
         ts = time.strftime("%H:%M:%S")
-        who = user if user else "unknown"
-        print(f"[{ts}] {who}: {event.text}")
+        user_info = "unknown"
+        if event.user_metadata and "user" in event.user_metadata:
+            user = event.user_metadata["user"]
+            user_info = str(user)
+        print(f"[{ts}] {user_info}: {event.text}")
         if hasattr(event, 'confidence') and event.confidence:
             print(f"  └─ confidence: {event.confidence:.2%}")
         if hasattr(event, 'processing_time_ms') and event.processing_time_ms:
diff --git a/examples/video_moderation/main.py b/examples/video_moderation/main.py
index ea7b2a65..a5f2dcdc 100644
--- a/examples/video_moderation/main.py
+++ b/examples/video_moderation/main.py
@@ -143,13 +143,17 @@ async def main(client: Stream):
     # Set up transcription handler
     @connection.on("audio")
     async def on_audio(pcm: PcmData, user):
-        # Process audio through Deepgram STT
-        await stt.process_audio(pcm, user)
+        # Process audio through Deepgram STT with user metadata
+        user_metadata = {"user": user} if user else None
+        await stt.process_audio(pcm, user_metadata)

     @stt.on("transcript")
     async def on_transcript(event):
         timestamp = time.strftime("%H:%M:%S")
-        user_info = user.name if user and hasattr(user, "name") else "unknown"
+        user_info = "unknown"
+        if event.user_metadata and "user" in event.user_metadata:
+            user = event.user_metadata["user"]
+            user_info = user.name if hasattr(user, "name") else str(user)
         print(f"[{timestamp}] {user_info}: {event.text}")
         if hasattr(event, 'confidence') and event.confidence:
             print(f"  └─ confidence: {event.confidence:.2%}")
diff --git a/getstream/plugins/common/events.py b/getstream/plugins/common/events.py
index 32938bd6..f9fae00f 100644
--- a/getstream/plugins/common/events.py
+++ b/getstream/plugins/common/events.py
@@ -93,12 +93,12 @@ def to_dict(self) -> Dict[str, Any]:
         result = {}
         import dataclasses

-        for field in dataclasses.fields(self):
-            field_value = getattr(self, field.name)
+        for field_info in dataclasses.fields(self):
+            field_value = getattr(self, field_info.name)
             if isinstance(field_value, (datetime, Enum)):
-                result[field.name] = field_value.value if isinstance(field_value, Enum) else str(field_value)
+                result[field_info.name] = field_value.value if isinstance(field_value, Enum) else str(field_value)
             else:
-                result[field.name] = field_value
+                result[field_info.name] = field_value

         return result

diff --git a/getstream/plugins/common/sts.py b/getstream/plugins/common/sts.py
index 15a0180d..248ab5fc 100644
--- a/getstream/plugins/common/sts.py
+++ b/getstream/plugins/common/sts.py
@@ -38,6 +38,7 @@ class STS(AsyncIOEventEmitter, abc.ABC):
     def __init__(
         self,
         *,
+        provider_name: Optional[str] = None,
         model: Optional[str] = None,
         instructions: Optional[str] = None,
         temperature: Optional[float] = None,
@@ -53,6 +54,7 @@ def __init__(
         to their own session/config structures. They are not enforced here.

         Args:
+            provider_name: Optional provider name override. Defaults to class name.
             model: Model ID to use when connecting.
             instructions: Optional system instructions passed to the session.
             temperature: Optional temperature passed to the session.
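Every example above makes the same two-part change: the audio handler wraps the speaker in a metadata dict before calling `stt.process_audio`, and the transcript handler recovers the speaker from `event.user_metadata` instead of closing over a `user` variable from an unrelated scope (which was stale or undefined in the old handlers). A minimal sketch of the pattern, assuming the `connection` and `stt` objects and the `PcmData` type used in the examples above:

    @connection.on("audio")
    async def on_audio(pcm: PcmData, user):
        # Wrap the speaker in a metadata dict; pass None when no user is known.
        user_metadata = {"user": user} if user else None
        await stt.process_audio(pcm, user_metadata)

    @stt.on("transcript")
    async def on_transcript(event):
        # Recover the speaker from the event itself rather than from a closure.
        user_info = "unknown"
        if event.user_metadata and "user" in event.user_metadata:
            user = event.user_metadata["user"]
            user_info = user.name if hasattr(user, "name") else str(user)
        print(f"{user_info}: {event.text}")
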
diff --git a/getstream/plugins/common/tests/test_events.py b/getstream/plugins/common/tests/test_events.py
index 533eaed5..cb015119 100644
--- a/getstream/plugins/common/tests/test_events.py
+++ b/getstream/plugins/common/tests/test_events.py
@@ -1,7 +1,6 @@
 import pytest
 import json
 from datetime import datetime
-from typing import Dict, Any

 from getstream.plugins.common.events import (
     # Base events
@@ -1097,7 +1096,7 @@ def test_round_trip_serialization_all_event_types(self):
             deserialized = deserialize_event(serialized)

             # Verify type and basic properties
-            assert type(deserialized) == type(original_event)
+            assert isinstance(deserialized, type(original_event))
             assert deserialized.event_type == original_event.event_type

             # Verify specific properties based on event type
diff --git a/getstream/plugins/elevenlabs/tests/test_tts.py b/getstream/plugins/elevenlabs/tests/test_tts.py
index 02b569d1..68a51e9f 100644
--- a/getstream/plugins/elevenlabs/tests/test_tts.py
+++ b/getstream/plugins/elevenlabs/tests/test_tts.py
@@ -1,7 +1,7 @@
 import os
 import pytest
 import asyncio
-from unittest.mock import patch, MagicMock, AsyncMock
+from unittest.mock import patch, MagicMock

 from getstream.plugins.elevenlabs.tts import ElevenLabsTTS
 from getstream.video.rtc.audio_track import AudioStreamTrack
@@ -27,12 +27,16 @@ def __init__(self, api_key=None):
         # Create a mock audio stream that returns a few chunks of audio
         mock_audio = [b"\x00\x00" * 1000, b"\x00\x00" * 1000]

-        # Mock the async stream method to return an iterable directly
-        self.text_to_speech.stream = AsyncMock(return_value=mock_audio)
+        # Mock the async stream method to return an async generator
+        async def mock_stream(*args, **kwargs):
+            for chunk in mock_audio:
+                yield chunk
+
+        self.text_to_speech.stream = mock_stream


 @pytest.mark.asyncio
-@patch("elevenlabs.client.AsyncElevenLabs", MockAsyncElevenLabsClient)
+@patch("getstream.plugins.elevenlabs.tts.tts.AsyncElevenLabs", MockAsyncElevenLabsClient)
 async def test_elevenlabs_tts_initialization():
     """Test that the ElevenLabs TTS initializes correctly with explicit API key."""
     tts = ElevenLabsTTS(api_key="test-api-key")
@@ -41,7 +45,7 @@ async def test_elevenlabs_tts_initialization():


 @pytest.mark.asyncio
-@patch("elevenlabs.client.AsyncElevenLabs", MockAsyncElevenLabsClient)
+@patch("getstream.plugins.elevenlabs.tts.tts.AsyncElevenLabs", MockAsyncElevenLabsClient)
 @patch.dict(os.environ, {"ELEVENLABS_API_KEY": "env-var-api-key"})
 async def test_elevenlabs_tts_initialization_with_env_var():
     """ElevenLabsTTS should use ELEVENLABS_API_KEY when no key argument is given."""
@@ -52,7 +56,7 @@ async def test_elevenlabs_tts_initialization_with_env_var():


 @pytest.mark.asyncio
-@patch("elevenlabs.client.AsyncElevenLabs", MockAsyncElevenLabsClient)
+@patch("getstream.plugins.elevenlabs.tts.tts.AsyncElevenLabs", MockAsyncElevenLabsClient)
 async def test_elevenlabs_tts_synthesize():
     """Test that synthesize returns an audio stream."""
     tts = ElevenLabsTTS(api_key="test-api-key")
@@ -61,17 +65,19 @@ async def test_elevenlabs_tts_synthesize():
     text = "Hello, world!"
     audio_stream = await tts.stream_audio(text)

-    # Check that it's an iterator
-    assert hasattr(audio_stream, "__iter__")
+    # Check that it's an async iterator
+    assert hasattr(audio_stream, "__aiter__")

     # Check that we can get chunks from it
-    chunks = list(audio_stream)
+    chunks = []
+    async for chunk in audio_stream:
+        chunks.append(chunk)
     assert len(chunks) > 0
     assert all(isinstance(chunk, bytes) for chunk in chunks)


 @pytest.mark.asyncio
-@patch("elevenlabs.client.AsyncElevenLabs", MockAsyncElevenLabsClient)
+@patch("getstream.plugins.elevenlabs.tts.tts.AsyncElevenLabs", MockAsyncElevenLabsClient)
 async def test_elevenlabs_tts_send():
     """Test that send writes audio to the track and emits events."""
     tts = ElevenLabsTTS(api_key="test-api-key")
diff --git a/getstream/plugins/kokoro/tts/tts.py b/getstream/plugins/kokoro/tts/tts.py
index 9d636227..8bc3d60b 100644
--- a/getstream/plugins/kokoro/tts/tts.py
+++ b/getstream/plugins/kokoro/tts/tts.py
@@ -42,7 +42,7 @@ def __init__(
         self.voice = voice
         self.speed = speed
         self.sample_rate = sample_rate
-        self.client = client if client is not None else self._pipeline,
+        self.client = client if client is not None else self._pipeline

     def set_output_track(self, track: AudioStreamTrack) -> None:  # noqa: D401
         if track.framerate != self.sample_rate:
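The kokoro change is a one-character bug fix worth calling out: a trailing comma after an assignment's right-hand side builds a one-element tuple, so `self.client` was a tuple wrapping the pipeline rather than the pipeline itself, and any later attribute access or call on it would fail. A quick illustration with a hypothetical value, just to show the semantics:

    pipeline = object()
    client = pipeline,   # trailing comma: client == (pipeline,), a 1-tuple
    client = pipeline    # fixed: client is the pipeline object itself
    assert not isinstance(client, tuple)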