diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 8a4a853..1f4a614 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -38,6 +38,7 @@ jobs:
python-version: ${{ matrix.py }}
- name: Setup test suite
run: |
+ sudo apt-get update && sudo apt-get install -y portaudio19-dev
python_version="${{ matrix.py }}"
python_version="${python_version/./}"
tox -f "py$python_version" -vvvv --notest
diff --git a/README.md b/README.md
index a9ddff9..69158f3 100644
--- a/README.md
+++ b/README.md
@@ -18,14 +18,23 @@ With a single API call, get access to AI models built on the latest AI breakthro
# Overview
+- [AssemblyAI's Python SDK](#assemblyais-python-sdk)
+- [Overview](#overview)
- [Documentation](#documentation)
-- [Installation](#installation)
-- [Example](#examples)
- - [Core Examples](#core-examples)
- - [LeMUR Examples](#lemur-examples)
- - [Audio Intelligence Examples](#audio-intelligence-examples)
-- [Playgrounds](#playgrounds)
-- [Advanced](#advanced-todo)
+- [Quick Start](#quick-start)
+ - [Installation](#installation)
+ - [Examples](#examples)
+ - [**Core Examples**](#core-examples)
+ - [**LeMUR Examples**](#lemur-examples)
+ - [**Audio Intelligence Examples**](#audio-intelligence-examples)
+ - [**Real-Time Examples**](#real-time-examples)
+ - [Playgrounds](#playgrounds)
+- [Advanced](#advanced)
+ - [How the SDK handles Default Configurations](#how-the-sdk-handles-default-configurations)
+ - [Defining Defaults](#defining-defaults)
+ - [Overriding Defaults](#overriding-defaults)
+ - [Synchronous vs Asynchronous](#synchronous-vs-asynchronous)
+ - [Polling Intervals](#polling-intervals)
# Documentation
@@ -470,6 +479,113 @@ for result in transcript.auto_highlights.results:
---
+### **Real-Time Examples**
+
+[Read more about our Real-Time service.](https://www.assemblyai.com/docs/Guides/real-time_streaming_transcription)
+
+
+ Stream your Microphone in Real-Time
+
+```python
+import assemblyai as aai
+
+def on_open(session_opened: aai.RealtimeSessionOpened):
+ "This function is called when the connection has been established."
+
+ print("Session ID:", session_opened.session_id)
+
+def on_data(transcript: aai.RealtimeTranscript):
+ "This function is called when a new transcript has been received."
+
+ if not transcript.text:
+ return
+
+ if isinstance(transcript, aai.RealtimeFinalTranscript):
+ print(transcript.text, end="\r\n")
+ else:
+ print(transcript.text, end="\r")
+
+def on_error(error: aai.RealtimeError):
+    "This function is called when an error occurs."
+
+    print("An error occurred:", error)
+
+def on_close():
+ "This function is called when the connection has been closed."
+
+ print("Closing Session")
+
+
+# Create the Real-Time transcriber
+transcriber = aai.RealtimeTranscriber(
+ on_data=on_data,
+ on_error=on_error,
+ sample_rate=44_100,
+ on_open=on_open, # optional
+ on_close=on_close, # optional
+)
+
+
+# Open a microphone stream
+microphone_stream = aai.extras.MicrophoneStream()
+
+# Press CTRL+C to abort
+transcriber.stream(microphone_stream)
+
+transcriber.close()
+```
+
+
+
+
+ Transcribe a Local Audio File in Real-Time
+
+```python
+import assemblyai as aai
+
+
+def on_data(transcript: aai.RealtimeTranscript):
+ "This function is called when a new transcript has been received."
+
+ if not transcript.text:
+ return
+
+ if isinstance(transcript, aai.RealtimeFinalTranscript):
+ print(transcript.text, end="\r\n")
+ else:
+ print(transcript.text, end="\r")
+
+def on_error(error: aai.RealtimeError):
+    "This function is called when an error occurs."
+
+    print("An error occurred:", error)
+
+
+# Create the Real-Time transcriber
+transcriber = aai.RealtimeTranscriber(
+ on_data=on_data,
+ on_error=on_error,
+ sample_rate=44_100,
+    # on_open and on_close callbacks are optional;
+    # see the microphone example above for how to define them.
+)
+
+
+# Only WAV/PCM16 single channel supported for now
+file_stream = aai.extras.stream_file(
+ filepath="audio.wav",
+ sample_rate=44_100,
+)
+
+transcriber.stream(file_stream)
+
+transcriber.close()
+```
+
+
+
+---
+
## Playgrounds
Visit one of our Playgrounds:
diff --git a/assemblyai/__init__.py b/assemblyai/__init__.py
index 530d65d..ae25dda 100644
--- a/assemblyai/__init__.py
+++ b/assemblyai/__init__.py
@@ -1,6 +1,7 @@
+from . import extras
from .client import Client
from .lemur import Lemur
-from .transcriber import Transcriber, Transcript, TranscriptGroup
+from .transcriber import RealtimeTranscriber, Transcriber, Transcript, TranscriptGroup
from .types import (
AssemblyAIError,
AutohighlightResponse,
@@ -24,6 +25,12 @@
PIIRedactionPolicy,
PIISubstitutionPolicy,
RawTranscriptionConfig,
+ RealtimeError,
+ RealtimeFinalTranscript,
+ RealtimePartialTranscript,
+ RealtimeSessionOpened,
+ RealtimeTranscript,
+ RealtimeWord,
Sentence,
Sentiment,
SentimentType,
@@ -93,6 +100,14 @@
"Word",
"WordBoost",
"WordSearchMatch",
+ "RealtimeError",
+ "RealtimeFinalTranscript",
+ "RealtimePartialTranscript",
+ "RealtimeSessionOpened",
+ "RealtimeTranscript",
+ "RealtimeWord",
# package globals
"settings",
+ # packages
+ "extras",
]
diff --git a/assemblyai/extras.py b/assemblyai/extras.py
new file mode 100644
index 0000000..9cf96c0
--- /dev/null
+++ b/assemblyai/extras.py
@@ -0,0 +1,102 @@
+import time
+from typing import Generator
+
+try:
+ import pyaudio
+except ImportError:
+ raise ImportError(
+ "You must install the extras for this SDK to use this feature. "
+ "Run `pip install assemblyai[extras]` to install the extras. "
+        "Make sure to run `apt install portaudio19-dev` (Debian/Ubuntu) or "
+ "`brew install portaudio` (MacOS) before installing the extras."
+ )
+
+
+class MicrophoneStream:
+ def __init__(
+ self,
+ sample_rate: int = 44_100,
+ ):
+ """
+ Creates a stream of audio from the microphone.
+
+ Args:
+ chunk_size: The size of each chunk of audio to read from the microphone.
+ channels: The number of channels to record audio from.
+ sample_rate: The sample rate to record audio at.
+ """
+
+ self._pyaudio = pyaudio.PyAudio()
+ self.sample_rate = sample_rate
+
+ self._chunk_size = int(self.sample_rate * 0.1)
+ self._stream = self._pyaudio.open(
+ format=pyaudio.paInt16,
+ channels=1,
+ rate=sample_rate,
+ input=True,
+ frames_per_buffer=self._chunk_size,
+ )
+
+ self._open = True
+
+ def __iter__(self):
+ """
+ Returns the iterator object.
+ """
+
+ return self
+
+ def __next__(self):
+ """
+ Reads a chunk of audio from the microphone.
+ """
+ if not self._open:
+ raise StopIteration
+
+ try:
+ return self._stream.read(self._chunk_size)
+ except KeyboardInterrupt:
+ raise StopIteration
+
+ def close(self):
+ """
+ Closes the stream.
+ """
+
+ self._open = False
+
+ if self._stream.is_active():
+ self._stream.stop_stream()
+
+ self._stream.close()
+ self._pyaudio.terminate()
+
+
+def stream_file(
+ filepath: str,
+ sample_rate: int,
+) -> Generator[bytes, None, None]:
+ """
+ Mimics a stream of audio data by reading it chunk by chunk from a file.
+
+ NOTE: Only supports WAV/PCM16 files as of now.
+
+ Args:
+ filepath: The path to the file to stream.
+ sample_rate: The sample rate of the audio file.
+
+ Returns: A generator that yields chunks of audio data.
+ """
+
+ with open(filepath, "rb") as f:
+ while True:
+ data = f.read(int(sample_rate * 0.30) * 2)
+ enough_data = ((len(data) / (16 / 8)) / sample_rate) * 1_000
+
+ if not data or enough_data < 300.0:
+ break
+
+ yield data
+
+ time.sleep(0.15)
diff --git a/assemblyai/transcriber.py b/assemblyai/transcriber.py
index 718fcdc..f06be81 100644
--- a/assemblyai/transcriber.py
+++ b/assemblyai/transcriber.py
@@ -1,12 +1,30 @@
from __future__ import annotations
+import base64
import concurrent.futures
+import json
import os
+import queue
+import threading
import time
-from typing import Dict, Iterator, List, Optional, Union
-from urllib.parse import urlparse
-
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Generator,
+ Iterable,
+ Iterator,
+ List,
+ Optional,
+ Union,
+)
+from urllib.parse import urlencode, urlparse
+
+import websockets
+import websockets.exceptions
+from httpx import request
from typing_extensions import Self
+from websockets.sync.client import connect as websocket_connect
from . import api
from . import client as _client
@@ -824,3 +842,270 @@ def transcribe_group_async(
config=config,
poll=True,
)
+
+
+class _RealtimeTranscriberImpl:
+ def __init__(
+ self,
+ *,
+ on_data: Callable[[types.RealtimeTranscript], None],
+ on_error: Callable[[types.RealtimeError], None],
+ on_open: Optional[Callable[[types.RealtimeSessionOpened], None]],
+ on_close: Optional[Callable[[], None]],
+ sample_rate: int,
+ word_boost: List[str],
+ client: _client.Client,
+ ) -> None:
+ self._client = client
+        self._websocket: Optional[websockets.sync.client.ClientConnection] = None
+
+ self._on_open = on_open
+ self._on_data = on_data
+ self._on_error = on_error
+ self._on_close = on_close
+ self._sample_rate = sample_rate
+ self._word_boost = word_boost
+
+ self._write_queue: queue.Queue[bytes] = queue.Queue()
+ self._write_thread = threading.Thread(target=self._write)
+ self._read_thread = threading.Thread(target=self._read)
+ self._stop_event = threading.Event()
+
+ def connect(
+ self,
+ timeout: Optional[float],
+ ) -> None:
+ """
+ Connects to the real-time service.
+
+ Args:
+ `timeout`: The maximum time to wait for the connection to be established.
+ """
+
+ params: Dict[str, Any] = {
+ "sample_rate": self._sample_rate,
+ }
+ if self._word_boost:
+ params["word_boost"] = self._word_boost
+
+ websocket_base_url = self._client.settings.base_url.replace("https", "wss")
+
+ try:
+ self._websocket = websocket_connect(
+ f"{websocket_base_url}/realtime/ws?{urlencode(params)}",
+ additional_headers={
+ "Authorization": f"{self._client.settings.api_key}"
+ },
+ open_timeout=timeout,
+ )
+ except Exception as exc:
+ return self._on_error(
+ types.RealtimeError(
+ f"Could not connect to the real-time service: {exc}"
+ )
+ )
+
+ self._read_thread.start()
+ self._write_thread.start()
+
+ def stream(self, data: bytes) -> None:
+ """
+ Streams audio data to the real-time service by putting it into a queue.
+ """
+
+ self._write_queue.put(data)
+
+ def close(self, terminate: bool = False) -> None:
+ """
+ Closes the connection to the real-time service gracefully.
+ """
+
+ with self._write_queue.mutex:
+ self._write_queue.queue.clear()
+
+ if terminate and not self._stop_event.is_set():
+ self._websocket.send(json.dumps({"terminate_session": True}))
+ self._websocket.close()
+
+ self._stop_event.set()
+
+ try:
+ self._read_thread.join()
+ self._write_thread.join()
+ except Exception:
+ pass
+
+ if self._on_close:
+ self._on_close()
+
+ def _read(self) -> None:
+ """
+ Reads messages from the real-time service.
+
+ Must run in a separate thread to avoid blocking the main thread.
+ """
+
+ while not self._stop_event.is_set():
+ try:
+ message = self._websocket.recv(timeout=1)
+ except TimeoutError:
+ continue
+ except websockets.exceptions.ConnectionClosed as exc:
+ return self._handle_error(exc)
+
+ try:
+ message = json.loads(message)
+ except json.JSONDecodeError as exc:
+ self._on_error(
+ types.RealtimeError(
+ f"Could not decode message: {exc}",
+ )
+ )
+ continue
+
+ self._handle_message(message)
+
+ def _write(self) -> None:
+ """
+ Writes messages to the real-time service.
+
+ Must run in a separate thread to avoid blocking the main thread.
+ """
+
+ while not self._stop_event.is_set():
+ try:
+ data = self._write_queue.get(timeout=1)
+ except queue.Empty:
+ continue
+
+ try:
+ self._websocket.send(self._encode_data(data))
+ except websockets.exceptions.ConnectionClosed as exc:
+ return self._handle_error(exc)
+
+ def _encode_data(self, data: bytes) -> str:
+ """
+ Encodes the given audio chunk as a base64 string.
+
+ This is a helper method for `_write`.
+ """
+
+ return json.dumps(
+ {
+ "audio_data": base64.b64encode(data).decode("utf-8"),
+ }
+ )
+
+ def _handle_message(
+ self,
+ message: Dict[str, Any],
+ ) -> None:
+ """
+ Handles a message received from the real-time service by calling the appropriate
+ callback.
+
+ Args:
+ `message`: The message to handle.
+ """
+ if "message_type" in message:
+ if message["message_type"] == types.RealtimeMessageTypes.partial_transcript:
+ self._on_data(types.RealtimePartialTranscript(**message))
+ elif message["message_type"] == types.RealtimeMessageTypes.final_transcript:
+ self._on_data(types.RealtimeFinalTranscript(**message))
+ elif (
+ message["message_type"] == types.RealtimeMessageTypes.session_begins
+ and self._on_open
+ ):
+ self._on_open(types.RealtimeSessionOpened(**message))
+ elif "error" in message:
+ self._on_error(types.RealtimeError(message["error"]))
+
+ def _handle_error(self, error: websockets.exceptions.ConnectionClosed) -> None:
+ """
+ Handles a WebSocket error by calling the appropriate callback.
+ """
+        if 4000 <= error.code <= 4999:
+            error_message = types.RealtimeErrorMapping.get(error.code, error.reason)
+ else:
+ error_message = error.reason
+
+ self._on_error(types.RealtimeError(error_message))
+ self.close()
+
+
+class RealtimeTranscriber:
+ def __init__(
+ self,
+ *,
+ on_data: Callable[[types.RealtimeTranscript], None],
+ on_error: Callable[[types.RealtimeError], None],
+ on_open: Optional[Callable[[types.RealtimeSessionOpened], None]] = None,
+ on_close: Optional[Callable[[], None]] = None,
+ sample_rate: int,
+        word_boost: Optional[List[str]] = None,
+ client: Optional[_client.Client] = None,
+ ) -> None:
+ """
+ Creates a new real-time transcriber.
+
+ Args:
+ `on_data`: The callback to call when a new transcript is received.
+ `on_error`: The callback to call when an error occurs.
+ `on_open`: (Optional) The callback to call when the connection to the real-time service
+ `on_close`: (Optional) The callback to call when the connection to the real-time service
+ `sample_rate`: The sample rate of the audio data.
+ `word_boost`: (Optional) A list of words to boost the confidence of.
+ `client`: (Optional) The client to use for the real-time service.
+ """
+
+ self._client = client or _client.Client.get_default()
+
+ self._impl = _RealtimeTranscriberImpl(
+ on_open=on_open,
+ on_data=on_data,
+ on_error=on_error,
+ on_close=on_close,
+ sample_rate=sample_rate,
+ word_boost=word_boost,
+ client=self._client,
+ )
+
+ def connect(
+ self,
+ timeout: Optional[float] = 10.0,
+ ) -> None:
+ """
+ Connects to the real-time service.
+
+ Args:
+ `timeout`: The timeout in seconds to wait for the connection to be established.
+ A `timeout` of `None` means no timeout.
+ """
+
+ self._impl.connect(timeout=timeout)
+
+ def stream(
+ self,
+ data: Union[bytes, Generator[bytes, None, None], Iterable[bytes]],
+ ) -> None:
+ """
+ Streams raw audio data to the real-time service.
+
+ Args:
+ `data`: Raw audio data in `bytes` or a generator/iterable of `bytes`.
+
+ Note: Make sure that `data` matches the `sample_rate` that was given in the constructor.
+ """
+ if isinstance(data, bytes):
+ self._impl.stream(data)
+ return
+
+ for chunk in data:
+ self._impl.stream(chunk)
+
+ def close(self) -> None:
+ """
+ Closes the connection to the real-time service.
+ """
+
+ self._impl.close(terminate=True)
diff --git a/assemblyai/types.py b/assemblyai/types.py
index fcb2692..bfce3fb 100644
--- a/assemblyai/types.py
+++ b/assemblyai/types.py
@@ -1,7 +1,8 @@
+from datetime import datetime
from enum import Enum
-from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union
-from pydantic import BaseModel, BaseSettings, Extra, Field
+from pydantic import UUID4, BaseModel, BaseSettings, Extra, Field
from typing_extensions import Self
@@ -1532,3 +1533,136 @@ class LemurCoachRequest(BaseModel):
class LemurCoachResponse(BaseModel):
response: str
model: LemurModel = LemurModel.default
+
+
+class RealtimeMessageTypes(str, Enum):
+ """
+ The type of message received from the real-time API
+ """
+
+ partial_transcript = "PartialTranscript"
+ final_transcript = "FinalTranscript"
+ session_begins = "SessionBegins"
+
+
+class RealtimeSessionOpened(BaseModel):
+ """
+ Once a real-time session is opened, the client will receive this message
+ """
+
+ message_type: Literal[
+ RealtimeMessageTypes.session_begins
+ ] = RealtimeMessageTypes.session_begins
+
+ session_id: UUID4
+ "Unique identifier for the established session."
+
+ expires_at: datetime
+ "Timestamp when this session will expire."
+
+
+class RealtimeWord(BaseModel):
+ """
+ A word in a real-time transcript
+ """
+
+ start: int
+ "Start time of word relative to session start, in milliseconds"
+
+ end: int
+ "End time of word relative to session start, in milliseconds"
+
+ confidence: float
+ "The confidence score of the word, between 0 and 1"
+
+ text: str
+ "The word itself"
+
+
+class RealtimeTranscript(BaseModel):
+ """
+ Base class for real-time transcript messages.
+ """
+
+ message_type: Literal[
+ RealtimeMessageTypes.partial_transcript, RealtimeMessageTypes.final_transcript
+ ]
+ "Describes the type of message"
+
+ audio_start: int
+ "Start time of audio sample relative to session start, in milliseconds"
+
+ audio_end: int
+ "End time of audio sample relative to session start, in milliseconds"
+
+ confidence: float
+ "The confidence score of the entire transcription, between 0 and 1"
+
+ text: str
+ "The transcript for your audio"
+
+ words: List[Word]
+ """
+ An array of objects, with the information for each word in the transcription text.
+ Will include the `start`/`end` time (in milliseconds) of the word, the `confidence` score of the word,
+ and the `text` (i.e. the word itself)
+ """
+
+ created: datetime
+ "Timestamp when this message was created"
+
+
+class RealtimePartialTranscript(RealtimeTranscript):
+ """
+ As you send audio data to the service, the service will immediately start responding with partial transcripts.
+ """
+
+ message_type: Literal[
+ RealtimeMessageTypes.partial_transcript
+ ] = RealtimeMessageTypes.partial_transcript
+
+
+class RealtimeFinalTranscript(RealtimeTranscript):
+ """
+ After you've received your partial results, our model will continue to analyze incoming audio and,
+ when it detects the end of an "utterance" (usually a pause in speech), it will finalize the results
+ sent to you so far with higher accuracy, as well as add punctuation and casing to the transcription text.
+ """
+
+ message_type: Literal[
+ RealtimeMessageTypes.final_transcript
+ ] = RealtimeMessageTypes.final_transcript
+
+ punctuated: bool
+ "Whether the transcript has been punctuated and cased"
+
+ text_formatted: bool
+ "Whether the transcript has been formatted (e.g. Dollar -> $)"
+
+
+class RealtimeError(AssemblyAIError):
+ """
+ Real-time error message
+ """
+
+
+RealtimeErrorMapping = {
+ 4000: "Sample rate must be a positive integer",
+ 4001: "Not Authorized",
+ 4002: "Insufficient Funds",
+ 4003: """This feature is paid-only and requires you to add a credit card.
+ Please visit https://app.assemblyai.com/ to add a credit card to your account""",
+ 4004: "Session Not Found",
+ 4008: "Session Expired",
+ 4010: "Session Previously Closed",
+ 4029: "Client sent audio too fast",
+ 4030: "Session is handled by another websocket",
+ 4031: "Session idle for too long",
+ 4032: "Audio duration is too short",
+ 4033: "Audio duration is too long",
+ 4100: "Endpoint received invalid JSON",
+ 4101: "Endpoint received a message with an invalid schema",
+ 4102: "This account has exceeded the number of allowed streams",
+ 4103: "The session has been reconnected. This websocket is no longer valid.",
+ 1013: "Temporary server condition forced blocking client's request",
+}
diff --git a/setup.py b/setup.py
index 36e6cf7..225205b 100644
--- a/setup.py
+++ b/setup.py
@@ -7,7 +7,7 @@
setup(
name="assemblyai",
- version="0.9.0",
+ version="0.10.0",
description="AssemblyAI Python SDK",
author="AssemblyAI",
author_email="engineering.sdk@assemblyai.com",
@@ -15,8 +15,12 @@
install_requires=[
"httpx>=0.19.0",
"pydantic>=1.7.0",
- "typing-extensions",
+ "typing-extensions>=3.7,<4.6",
+ "websockets>=11.0",
],
+ extras_require={
+ "extras": ["pyaudio>=0.2.13"],
+ },
classifiers=[
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",
diff --git a/tests/unit/test_realtime_transcriber.py b/tests/unit/test_realtime_transcriber.py
new file mode 100644
index 0000000..cb6f8a4
--- /dev/null
+++ b/tests/unit/test_realtime_transcriber.py
@@ -0,0 +1,490 @@
+import datetime
+import json
+import uuid
+from typing import Optional
+from unittest.mock import MagicMock
+from urllib.parse import urlencode
+
+import pytest
+import websockets.exceptions
+from faker import Faker
+from pytest_mock import MockFixture
+
+import assemblyai as aai
+
+aai.settings.api_key = "test"
+aai.settings.base_url = "https://api.assemblyai.com/v2"
+
+
+def _disable_rw_threads(mocker: MockFixture):
+ """
+ Disable the read/write threads for the websocket
+ """
+
+ mocker.patch("threading.Thread.start", return_value=None)
+
+
+def test_realtime_connect_has_parameters(mocker: MockFixture):
+ """
+ Test that the connect method has the correct parameters set
+ """
+ actual_url = None
+ actual_additional_headers = None
+ actual_open_timeout = None
+
+ def mocked_websocket_connect(
+ url: str, additional_headers: dict, open_timeout: float
+ ):
+ nonlocal actual_url, actual_additional_headers, actual_open_timeout
+ actual_url = url
+ actual_additional_headers = additional_headers
+ actual_open_timeout = open_timeout
+
+ mocker.patch(
+ "assemblyai.transcriber.websocket_connect",
+ new=mocked_websocket_connect,
+ )
+ _disable_rw_threads(mocker)
+
+ transcriber = aai.RealtimeTranscriber(
+        on_data=lambda _: None,
+ on_error=lambda error: print(error),
+ sample_rate=44_100,
+ word_boost=["AssemblyAI"],
+ )
+
+ transcriber.connect(timeout=15.0)
+
+ assert (
+ actual_url
+ == f"wss://api.assemblyai.com/v2/realtime/ws?{urlencode(dict(sample_rate=44100, word_boost=['AssemblyAI']))}"
+ )
+ assert actual_additional_headers == {"Authorization": aai.settings.api_key}
+ assert actual_open_timeout == 15.0
+
+
+def test_realtime_connect_succeeds(mocker: MockFixture):
+ """
+ Tests that the `RealtimeTranscriber` successfully connects to the `real-time` service.
+ """
+ on_error_called = False
+
+ def on_error(error: aai.RealtimeError):
+ nonlocal on_error_called
+ on_error_called = True
+
+ transcriber = aai.RealtimeTranscriber(
+ on_data=lambda _: None,
+ on_error=on_error,
+ sample_rate=44_100,
+ )
+
+ mocker.patch(
+ "assemblyai.transcriber.websocket_connect",
+ return_value=MagicMock(),
+ )
+
+ # mock the read/write threads
+ _disable_rw_threads(mocker)
+
+ # should pass
+ transcriber.connect()
+
+ # no errors should be called
+ assert not on_error_called
+
+
+def test_realtime_connect_fails(mocker: MockFixture):
+ """
+ Tests that the `RealtimeTranscriber` fails to connect to the `real-time` service.
+ """
+
+ on_error_called = False
+
+ def on_error(error: aai.RealtimeError):
+ nonlocal on_error_called
+ on_error_called = True
+
+ assert isinstance(error, aai.RealtimeError)
+ assert "connection failed" in str(error)
+
+ transcriber = aai.RealtimeTranscriber(
+ on_data=lambda _: None,
+ on_error=on_error,
+ sample_rate=44_100,
+ )
+ mocker.patch(
+ "assemblyai.transcriber.websocket_connect",
+ side_effect=Exception("connection failed"),
+ )
+
+ transcriber.connect()
+
+ assert on_error_called
+
+
+def test_realtime__read_succeeds(mocker: MockFixture, faker: Faker):
+ """
+ Tests the `_read` method of the `_RealtimeTranscriberImpl` class.
+ """
+
+ expected_transcripts = [
+ aai.RealtimeFinalTranscript(
+ created=faker.date_time(),
+ text=faker.sentence(),
+ audio_start=0,
+ audio_end=1,
+ confidence=1.0,
+ words=[],
+ punctuated=True,
+ text_formatted=True,
+ )
+ ]
+
+ received_transcripts = []
+
+ def on_data(data: aai.RealtimeTranscript):
+ nonlocal received_transcripts
+ received_transcripts.append(data)
+
+ transcriber = aai.RealtimeTranscriber(
+ on_data=on_data,
+ on_error=lambda _: None,
+ sample_rate=44_100,
+ )
+
+ transcriber._impl._websocket = MagicMock()
+ websocket_recv = [
+ json.dumps(msg.dict(), default=str) for msg in expected_transcripts
+ ]
+ transcriber._impl._websocket.recv.side_effect = websocket_recv
+
+ with pytest.raises(StopIteration):
+ transcriber._impl._read()
+
+ assert received_transcripts == expected_transcripts
+
+
+def test_realtime__read_fails(mocker: MockFixture):
+ """
+ Tests the `_read` method of the `_RealtimeTranscriberImpl` class.
+ """
+
+ on_error_called = False
+
+ def on_error(error: aai.RealtimeError):
+ nonlocal on_error_called
+ on_error_called = True
+
+ transcriber = aai.RealtimeTranscriber(
+ on_data=lambda _: None,
+ on_error=on_error,
+ sample_rate=44_100,
+ )
+
+ transcriber._impl._websocket = MagicMock()
+ error = websockets.exceptions.ConnectionClosedOK(rcvd=None, sent=None)
+ transcriber._impl._websocket.recv.side_effect = error
+
+ transcriber._impl._read()
+
+ assert on_error_called
+
+
+def test_realtime__write_succeeds(mocker: MockFixture):
+ """
+ Tests the `_write` method of the `_RealtimeTranscriberImpl` class.
+ """
+ audio_chunks = [
+ bytes([1, 2, 3, 4, 5]),
+ bytes([6, 7, 8, 9, 10]),
+ ]
+
+ actual_sent = []
+
+ def mocked_send(data: str):
+ nonlocal actual_sent
+ actual_sent.append(data)
+
+ transcriber = aai.RealtimeTranscriber(
+ on_data=lambda _: None,
+ on_error=lambda _: None,
+ sample_rate=44_100,
+ )
+
+ transcriber._impl._websocket = MagicMock()
+ transcriber._impl._websocket.send = mocked_send
+ transcriber._impl._stop_event.is_set = MagicMock(side_effect=[False, False, True])
+
+ transcriber.stream(audio_chunks[0])
+ transcriber.stream(audio_chunks[1])
+
+ transcriber._impl._write()
+
+ # assert that the correct data was sent (base64 encoded)
+ assert len(actual_sent) == 2
+ assert json.loads(actual_sent[0]) == {"audio_data": "AQIDBAU="}
+ assert json.loads(actual_sent[1]) == {"audio_data": "BgcICQo="}
+
+
+def test_realtime__encode_data(mocker: MockFixture):
+ """
+ Tests the `_encode_data` method of the `_RealtimeTranscriberImpl` class.
+ """
+
+ audio_chunks = [
+ bytes([1, 2, 3, 4, 5]),
+ bytes([6, 7, 8, 9, 10]),
+ ]
+
+ expected_encoded_data = [
+ json.dumps({"audio_data": "AQIDBAU="}),
+ json.dumps({"audio_data": "BgcICQo="}),
+ ]
+
+ transcriber = aai.RealtimeTranscriber(
+ on_data=lambda _: None,
+ on_error=lambda _: None,
+ sample_rate=44_100,
+ )
+
+ actual_encoded_data = []
+ for chunk in audio_chunks:
+ actual_encoded_data.append(transcriber._impl._encode_data(chunk))
+
+ assert actual_encoded_data == expected_encoded_data
+
+
+def test_realtime__handle_message_session_begins(mocker: MockFixture):
+ """
+ Tests the `_handle_message` method of the `_RealtimeTranscriberImpl` class
+ with the `SessionBegins` message.
+ """
+
+ test_message = {
+ "message_type": "SessionBegins",
+ "session_id": str(uuid.uuid4()),
+ "expires_at": datetime.datetime.now().isoformat(),
+ }
+
+ on_open_called = False
+
+ def on_open(session_opened: aai.RealtimeSessionOpened):
+ nonlocal on_open_called
+ on_open_called = True
+ assert isinstance(session_opened, aai.RealtimeSessionOpened)
+ assert session_opened.session_id == uuid.UUID(test_message["session_id"])
+ assert session_opened.expires_at.isoformat() == test_message["expires_at"]
+
+ transcriber = aai.RealtimeTranscriber(
+ on_open=on_open,
+ on_data=lambda _: None,
+ on_error=lambda _: None,
+ sample_rate=44_100,
+ )
+
+ transcriber._impl._handle_message(test_message)
+
+ assert on_open_called
+
+
+def test_realtime__handle_message_partial_transcript(mocker: MockFixture):
+ """
+ Tests the `_handle_message` method of the `_RealtimeTranscriberImpl` class
+ with the `PartialTranscript` message.
+ """
+
+ test_message = {
+ "message_type": "PartialTranscript",
+ "text": "hello world",
+ "audio_start": 0,
+ "audio_end": 1500,
+ "confidence": 0.99,
+ "created": datetime.datetime.now().isoformat(),
+ "words": [
+ {
+ "text": "hello",
+ "start": 0,
+ "end": 500,
+ "confidence": 0.99,
+ },
+ {
+ "text": "world",
+ "start": 500,
+ "end": 1500,
+ "confidence": 0.99,
+ },
+ ],
+ }
+
+ on_data_called = False
+
+ def on_data(data: aai.RealtimePartialTranscript):
+ nonlocal on_data_called
+ on_data_called = True
+ assert isinstance(data, aai.RealtimePartialTranscript)
+ assert data.text == test_message["text"]
+ assert data.audio_start == test_message["audio_start"]
+ assert data.audio_end == test_message["audio_end"]
+ assert data.confidence == test_message["confidence"]
+ assert data.created.isoformat() == test_message["created"]
+ assert data.words == [
+ aai.RealtimeWord(
+ text=test_message["words"][0]["text"],
+ start=test_message["words"][0]["start"],
+ end=test_message["words"][0]["end"],
+ confidence=test_message["words"][0]["confidence"],
+ ),
+ aai.RealtimeWord(
+ text=test_message["words"][1]["text"],
+ start=test_message["words"][1]["start"],
+ end=test_message["words"][1]["end"],
+ confidence=test_message["words"][1]["confidence"],
+ ),
+ ]
+
+ transcriber = aai.RealtimeTranscriber(
+ on_data=on_data,
+ on_error=lambda _: None,
+ sample_rate=44_100,
+ )
+
+ transcriber._impl._handle_message(test_message)
+
+ assert on_data_called
+
+
+def test_realtime__handle_message_final_transcript(mocker: MockFixture):
+ """
+ Tests the `_handle_message` method of the `_RealtimeTranscriberImpl` class
+ with the `FinalTranscript` message.
+ """
+
+ test_message = {
+ "message_type": "FinalTranscript",
+ "text": "Hello, world!",
+ "audio_start": 0,
+ "audio_end": 1500,
+ "confidence": 0.99,
+ "created": datetime.datetime.now().isoformat(),
+ "punctuated": True,
+ "text_formatted": True,
+ "words": [
+ {
+ "text": "Hello,",
+ "start": 0,
+ "end": 500,
+ "confidence": 0.99,
+ },
+ {
+ "text": "world!",
+ "start": 500,
+ "end": 1500,
+ "confidence": 0.99,
+ },
+ ],
+ }
+
+ on_data_called = False
+
+ def on_data(data: aai.RealtimeFinalTranscript):
+ nonlocal on_data_called
+ on_data_called = True
+ assert isinstance(data, aai.RealtimeFinalTranscript)
+ assert data.text == test_message["text"]
+ assert data.audio_start == test_message["audio_start"]
+ assert data.audio_end == test_message["audio_end"]
+ assert data.confidence == test_message["confidence"]
+ assert data.created.isoformat() == test_message["created"]
+ assert data.punctuated == test_message["punctuated"]
+ assert data.text_formatted == test_message["text_formatted"]
+ assert data.words == [
+ aai.RealtimeWord(
+ text=test_message["words"][0]["text"],
+ start=test_message["words"][0]["start"],
+ end=test_message["words"][0]["end"],
+ confidence=test_message["words"][0]["confidence"],
+ ),
+ aai.RealtimeWord(
+ text=test_message["words"][1]["text"],
+ start=test_message["words"][1]["start"],
+ end=test_message["words"][1]["end"],
+ confidence=test_message["words"][1]["confidence"],
+ ),
+ ]
+
+ transcriber = aai.RealtimeTranscriber(
+ on_data=on_data,
+ on_error=lambda _: None,
+ sample_rate=44_100,
+ )
+
+ transcriber._impl._handle_message(test_message)
+
+ assert on_data_called
+
+
+def test_realtime__handle_message_error_message(mocker: MockFixture):
+ """
+ Tests the `_handle_message` method of the `_RealtimeTranscriberImpl` class
+ with the error message.
+ """
+
+ test_message = {
+ "error": "test error",
+ }
+
+ on_error_called = False
+
+ def on_error(error: aai.RealtimeError):
+ nonlocal on_error_called
+ on_error_called = True
+ assert isinstance(error, aai.RealtimeError)
+ assert str(error) == test_message["error"]
+
+ transcriber = aai.RealtimeTranscriber(
+ on_data=lambda _: None,
+ on_error=on_error,
+ sample_rate=44_100,
+ )
+
+ transcriber._impl._handle_message(test_message)
+
+ assert on_error_called
+
+
+def test_realtime__handle_message_unknown_message(mocker: MockFixture):
+ """
+ Tests the `_handle_message` method of the `_RealtimeTranscriberImpl` class
+ with an unknown message.
+ """
+
+ test_message = {
+ "message_type": "Unknown",
+ }
+
+ on_data_called = False
+
+ def on_data(data: aai.RealtimeTranscript):
+ nonlocal on_data_called
+ on_data_called = True
+
+ on_error_called = False
+
+ def on_error(error: aai.RealtimeError):
+ nonlocal on_error_called
+ on_error_called = True
+
+ transcriber = aai.RealtimeTranscriber(
+ on_data=on_data,
+ on_error=on_error,
+ sample_rate=44_100,
+ )
+
+ transcriber._impl._handle_message(test_message)
+
+ assert not on_data_called
+ assert not on_error_called
+
+
+# TODO: create tests for the `RealtimeTranscriber.close` method
diff --git a/tox.ini b/tox.ini
index a5327f1..b10d573 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,9 +1,11 @@
[tox]
-envlist = py{38,39,310,311}-httpx{latest,0.24,0.23,0.22,0.21}-pydantic{latest,1.10,1.9,1.8,1.7}-typing-extensions
+envlist = py{38,39,310,311}-websockets{latest,11.0}-pyaudio{latest,0.2}-httpx{latest,0.24,0.23,0.22,0.21}-pydantic{latest,1.10,1.9,1.8,1.7}-typing-extensions
[testenv]
deps =
# library dependencies
+ websocketslatest: websockets
+ websockets11.0: websockets>=11.0.0,<12.0.0
httpxlatest: httpx
httpx0.24: httpx>=0.24.0,<0.25.0
httpx0.23: httpx>=0.23.0,<0.24.0
@@ -14,11 +16,15 @@ deps =
pydantic1.9: pydantic>=1.9.0,<1.10.0
pydantic1.8: pydantic>=1.8.0,<1.9.0
pydantic1.7: pydantic>=1.7.0,<1.8.0
- typing-extensions: typing-extensions>=3.7
+ typing-extensions: typing-extensions>=3.7,<4.6
+ # extra dependencies
+ pyaudiolatest: pyaudio
+ pyaudio0.2: pyaudio>=0.2.13,<0.3.0
# test dependencies
pytest
pytest-httpx
pytest-xdist
+ pytest-mock
pytest-cov
factory-boy
allowlist_externals = pytest