Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions examples/audio_moderation/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,13 +141,17 @@ async def main(client: Stream):
# Set up transcription handler
@connection.on("audio")
async def on_audio(pcm: PcmData, user):
# Process audio through Deepgram STT
await stt.process_audio(pcm, user)
# Process audio through Deepgram STT with user metadata
user_metadata = {"user": user} if user else None
await stt.process_audio(pcm, user_metadata)

@stt.on("transcript")
async def on_transcript(event):
timestamp = time.strftime("%H:%M:%S")
user_info = user.name if user and hasattr(user, "name") else "unknown"
user_info = "unknown"
if event.user_metadata and "user" in event.user_metadata:
user = event.user_metadata["user"]
user_info = user.name if hasattr(user, "name") else str(user)
print(f"[{timestamp}] {user_info}: {event.text}")
if hasattr(event, 'confidence') and event.confidence:
print(f" └─ confidence: {event.confidence:.2%}")
Expand Down
10 changes: 7 additions & 3 deletions examples/fal_stt_translate/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,17 @@ async def on_speech_detected(pcm: PcmData, user):
print(
f"{time.time()} Speech detected from user: {user} duration {pcm.duration}"
)
# Process audio through FAL.ai STT
await stt.process_audio(pcm, user)
# Process audio through FAL.ai STT with user metadata
user_metadata = {"user": user} if user else None
await stt.process_audio(pcm, user_metadata)

@stt.on("transcript")
async def on_transcript(event):
timestamp = time.strftime("%H:%M:%S")
user_info = user.name if user and hasattr(user, "name") else "unknown"
user_info = "unknown"
if event.user_metadata and "user" in event.user_metadata:
user = event.user_metadata["user"]
user_info = user.name if hasattr(user, "name") else str(user)
print(f"[{timestamp}] {user_info}: {event.text}")
if hasattr(event, 'confidence') and event.confidence:
print(f" └─ confidence: {event.confidence:.2%}")
Expand Down
11 changes: 8 additions & 3 deletions examples/llm_audio_conversation/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import asyncio
import os
import time
from typing import Any
from uuid import uuid4
import webbrowser
from urllib.parse import urlencode
Expand Down Expand Up @@ -147,12 +146,18 @@ async def on_speech_detected(pcm: PcmData, user):
print(
f"{time.time()} Speech detected from user: {user} duration {pcm.duration}"
)
await stt.process_audio(pcm, user)
# Process audio through STT with user metadata
user_metadata = {"user": user} if user else None
await stt.process_audio(pcm, user_metadata)

@stt.on("transcript")
async def on_transcript(event):
user_info = "unknown"
if event.user_metadata and "user" in event.user_metadata:
user = event.user_metadata["user"]
user_info = str(user)
print(
f"{time.time()} got text from user {user}, with metadata {event.to_dict()}"
f"{time.time()} got text from user {user_info}, with metadata {event.to_dict()}"
f"will send the transcript to the LLM: {event.text}"
)

Expand Down
10 changes: 8 additions & 2 deletions examples/mcp/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,17 @@ async def run_bot(call: Call, bot_user_id: str):
@connection.on("audio")
async def on_audio(pcm: PcmData, user):
"""Pipe raw PCM into the STT engine."""
await stt.process_audio(pcm, user)
# Process audio through STT with user metadata
user_metadata = {"user": user} if user else None
await stt.process_audio(pcm, user_metadata)

@stt.on("transcript")
async def on_transcript(event):
logging.info("🗣️ %s: %s", user or "unknown", event.text)
user_info = "unknown"
if event.user_metadata and "user" in event.user_metadata:
user = event.user_metadata["user"]
user_info = str(user)
logging.info("🗣️ %s: %s", user_info, event.text)

# Ask the LLM; it may decide to call an MCP tool.
answer = await chat_with_tools(event.text, mcp_client)
Expand Down
19 changes: 12 additions & 7 deletions examples/stt_deepgram_transcription/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import uuid
import webbrowser
from urllib.parse import urlencode
from typing import Any

from dotenv import load_dotenv

Expand Down Expand Up @@ -125,13 +124,18 @@ async def main():
# Set up transcription handlers
@connection.on("audio")
async def on_audio(pcm: PcmData, user):
# Process audio through Deepgram STT
await stt.process_audio(pcm, user)
# Process audio through Deepgram STT with user metadata
user_metadata = {"user": user} if user else None
await stt.process_audio(pcm, user_metadata)

@stt.on("transcript")
async def on_transcript(event):
timestamp = time.strftime("%H:%M:%S")
print(f"[{timestamp}] {event.user_metadata.name}: {event.text}")
user_info = "unknown"
if event.user_metadata and "user" in event.user_metadata:
user = event.user_metadata["user"]
user_info = user.name if hasattr(user, "name") else str(user)
print(f"[{timestamp}] {user_info}: {event.text}")
if hasattr(event, 'confidence') and event.confidence:
print(f" └─ confidence: {event.confidence:.2%}")
if hasattr(event, 'processing_time_ms') and event.processing_time_ms:
Expand All @@ -140,9 +144,10 @@ async def on_transcript(event):
@stt.on("partial_transcript")
async def on_partial_transcript(event):
if event.text.strip(): # Only show non-empty partial transcripts
user_info = (
user.name if user and hasattr(user, "name") else "unknown"
)
user_info = "unknown"
if event.user_metadata and "user" in event.user_metadata:
user = event.user_metadata["user"]
user_info = user.name if hasattr(user, "name") else str(user)
print(
f" {user_info} (partial): {event.text}", end="\r"
) # Overwrite line
Expand Down
11 changes: 8 additions & 3 deletions examples/stt_moonshine_transcription/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,18 @@ async def _on_speech_detected(pcm: PcmData, user):
print(
f"🎤 Speech detected from user: {user.name}, duration: {pcm.duration:.2f}s"
)
await stt.process_audio(pcm, user)
# Process audio through STT with user metadata
user_metadata = {"user": user} if user else None
await stt.process_audio(pcm, user_metadata)

@stt.on("transcript")
async def _on_transcript(event):
ts = time.strftime("%H:%M:%S")
who = user if user else "unknown"
print(f"[{ts}] {who}: {event.text}")
user_info = "unknown"
if event.user_metadata and "user" in event.user_metadata:
user = event.user_metadata["user"]
user_info = str(user)
print(f"[{ts}] {user_info}: {event.text}")
if hasattr(event, 'confidence') and event.confidence:
print(f" └─ confidence: {event.confidence:.2%}")
if hasattr(event, 'processing_time_ms') and event.processing_time_ms:
Expand Down
10 changes: 7 additions & 3 deletions examples/video_moderation/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,13 +143,17 @@ async def main(client: Stream):
# Set up transcription handler
@connection.on("audio")
async def on_audio(pcm: PcmData, user):
# Process audio through Deepgram STT
await stt.process_audio(pcm, user)
# Process audio through Deepgram STT with user metadata
user_metadata = {"user": user} if user else None
await stt.process_audio(pcm, user_metadata)

@stt.on("transcript")
async def on_transcript(event):
timestamp = time.strftime("%H:%M:%S")
user_info = user.name if user and hasattr(user, "name") else "unknown"
user_info = "unknown"
if event.user_metadata and "user" in event.user_metadata:
user = event.user_metadata["user"]
user_info = user.name if hasattr(user, "name") else str(user)
print(f"[{timestamp}] {user_info}: {event.text}")
if hasattr(event, 'confidence') and event.confidence:
print(f" └─ confidence: {event.confidence:.2%}")
Expand Down
8 changes: 4 additions & 4 deletions getstream/plugins/common/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,12 @@ def to_dict(self) -> Dict[str, Any]:
result = {}

import dataclasses
for field in dataclasses.fields(self):
field_value = getattr(self, field.name)
for field_info in dataclasses.fields(self):
field_value = getattr(self, field_info.name)
if isinstance(field_value, (datetime, Enum)):
result[field.name] = field_value.value if isinstance(field_value, Enum) else str(field_value)
result[field_info.name] = field_value.value if isinstance(field_value, Enum) else str(field_value)
else:
result[field.name] = field_value
result[field_info.name] = field_value
return result


Expand Down
2 changes: 2 additions & 0 deletions getstream/plugins/common/sts.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class STS(AsyncIOEventEmitter, abc.ABC):
def __init__(
self,
*,
provider_name: Optional[str] = None,
model: Optional[str] = None,
instructions: Optional[str] = None,
temperature: Optional[float] = None,
Expand All @@ -53,6 +54,7 @@ def __init__(
to their own session/config structures. They are not enforced here.

Args:
provider_name: Optional provider name override. Defaults to class name.
model: Model ID to use when connecting.
instructions: Optional system instructions passed to the session.
temperature: Optional temperature passed to the session.
Expand Down
3 changes: 1 addition & 2 deletions getstream/plugins/common/tests/test_events.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pytest
import json
from datetime import datetime
from typing import Dict, Any

from getstream.plugins.common.events import (
# Base events
Expand Down Expand Up @@ -1097,7 +1096,7 @@ def test_round_trip_serialization_all_event_types(self):
deserialized = deserialize_event(serialized)

# Verify type and basic properties
assert type(deserialized) == type(original_event)
assert isinstance(deserialized, type(original_event))
assert deserialized.event_type == original_event.event_type

# Verify specific properties based on event type
Expand Down
26 changes: 16 additions & 10 deletions getstream/plugins/elevenlabs/tests/test_tts.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import pytest
import asyncio
from unittest.mock import patch, MagicMock, AsyncMock
from unittest.mock import patch, MagicMock

from getstream.plugins.elevenlabs.tts import ElevenLabsTTS
from getstream.video.rtc.audio_track import AudioStreamTrack
Expand All @@ -27,12 +27,16 @@ def __init__(self, api_key=None):
# Create a mock audio stream that returns a few chunks of audio
mock_audio = [b"\x00\x00" * 1000, b"\x00\x00" * 1000]

# Mock the async stream method to return an iterable directly
self.text_to_speech.stream = AsyncMock(return_value=mock_audio)
# Mock the async stream method to return an async generator
async def mock_stream(*args, **kwargs):
for chunk in mock_audio:
yield chunk

self.text_to_speech.stream = mock_stream


@pytest.mark.asyncio
@patch("elevenlabs.client.AsyncElevenLabs", MockAsyncElevenLabsClient)
@patch("getstream.plugins.elevenlabs.tts.tts.AsyncElevenLabs", MockAsyncElevenLabsClient)
async def test_elevenlabs_tts_initialization():
"""Test that the ElevenLabs TTS initializes correctly with explicit API key."""
tts = ElevenLabsTTS(api_key="test-api-key")
Expand All @@ -41,7 +45,7 @@ async def test_elevenlabs_tts_initialization():


@pytest.mark.asyncio
@patch("elevenlabs.client.AsyncElevenLabs", MockAsyncElevenLabsClient)
@patch("getstream.plugins.elevenlabs.tts.tts.AsyncElevenLabs", MockAsyncElevenLabsClient)
@patch.dict(os.environ, {"ELEVENLABS_API_KEY": "env-var-api-key"})
async def test_elevenlabs_tts_initialization_with_env_var():
"""ElevenLabsTTS should use ELEVENLABS_API_KEY when no key argument is given."""
Expand All @@ -52,7 +56,7 @@ async def test_elevenlabs_tts_initialization_with_env_var():


@pytest.mark.asyncio
@patch("elevenlabs.client.AsyncElevenLabs", MockAsyncElevenLabsClient)
@patch("getstream.plugins.elevenlabs.tts.tts.AsyncElevenLabs", MockAsyncElevenLabsClient)
async def test_elevenlabs_tts_synthesize():
"""Test that synthesize returns an audio stream."""
tts = ElevenLabsTTS(api_key="test-api-key")
Expand All @@ -61,17 +65,19 @@ async def test_elevenlabs_tts_synthesize():
text = "Hello, world!"
audio_stream = await tts.stream_audio(text)

# Check that it's an iterator
assert hasattr(audio_stream, "__iter__")
# Check that it's an async iterator
assert hasattr(audio_stream, "__aiter__")

# Check that we can get chunks from it
chunks = list(audio_stream)
chunks = []
async for chunk in audio_stream:
chunks.append(chunk)
assert len(chunks) > 0
assert all(isinstance(chunk, bytes) for chunk in chunks)


@pytest.mark.asyncio
@patch("elevenlabs.client.AsyncElevenLabs", MockAsyncElevenLabsClient)
@patch("getstream.plugins.elevenlabs.tts.tts.AsyncElevenLabs", MockAsyncElevenLabsClient)
async def test_elevenlabs_tts_send():
"""Test that send writes audio to the track and emits events."""
tts = ElevenLabsTTS(api_key="test-api-key")
Expand Down
2 changes: 1 addition & 1 deletion getstream/plugins/kokoro/tts/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(
self.voice = voice
self.speed = speed
self.sample_rate = sample_rate
self.client = client if client is not None else self._pipeline,
self.client = client if client is not None else self._pipeline

def set_output_track(self, track: AudioStreamTrack) -> None: # noqa: D401
if track.framerate != self.sample_rate:
Expand Down