diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8a4a853..1f4a614 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -38,6 +38,7 @@ jobs: python-version: ${{ matrix.py }} - name: Setup test suite run: | + sudo apt-get update && sudo apt-get install -y portaudio19-dev python_version="${{ matrix.py }}" python_version="${python_version/./}" tox -f "py$python_version" -vvvv --notest diff --git a/README.md b/README.md index a9ddff9..69158f3 100644 --- a/README.md +++ b/README.md @@ -18,14 +18,23 @@ With a single API call, get access to AI models built on the latest AI breakthro # Overview +- [AssemblyAI's Python SDK](#assemblyais-python-sdk) +- [Overview](#overview) - [Documentation](#documentation) -- [Installation](#installation) -- [Example](#examples) - - [Core Examples](#core-examples) - - [LeMUR Examples](#lemur-examples) - - [Audio Intelligence Examples](#audio-intelligence-examples) -- [Playgrounds](#playgrounds) -- [Advanced](#advanced-todo) +- [Quick Start](#quick-start) + - [Installation](#installation) + - [Examples](#examples) + - [**Core Examples**](#core-examples) + - [**LeMUR Examples**](#lemur-examples) + - [**Audio Intelligence Examples**](#audio-intelligence-examples) + - [**Real-Time Examples**](#real-time-examples) + - [Playgrounds](#playgrounds) +- [Advanced](#advanced) + - [How the SDK handles Default Configurations](#how-the-sdk-handles-default-configurations) + - [Defining Defaults](#defining-defaults) + - [Overriding Defaults](#overriding-defaults) + - [Synchronous vs Asynchronous](#synchronous-vs-asynchronous) + - [Polling Intervals](#polling-intervals) # Documentation @@ -470,6 +479,113 @@ for result in transcript.auto_highlights.results: --- +### **Real-Time Examples** + +[Read more about our Real-Time service.](https://www.assemblyai.com/docs/Guides/real-time_streaming_transcription) + +
+ Stream your Microphone in Real-Time + +```python +import assemblyai as aai + +def on_open(session_opened: aai.RealtimeSessionOpened): + "This function is called when the connection has been established." + + print("Session ID:", session_opened.session_id) + +def on_data(transcript: aai.RealtimeTranscript): + "This function is called when a new transcript has been received." + + if not transcript.text: + return + + if isinstance(transcript, aai.RealtimeFinalTranscript): + print(transcript.text, end="\r\n") + else: + print(transcript.text, end="\r") + +def on_error(error: aai.RealtimeError): + "This function is called when the connection has been closed." + + print("An error occured:", error) + +def on_close(): + "This function is called when the connection has been closed." + + print("Closing Session") + + +# Create the Real-Time transcriber +transcriber = aai.RealtimeTranscriber( + on_data=on_data, + on_error=on_error, + sample_rate=44_100, + on_open=on_open, # optional + on_close=on_close, # optional +) + + +# Open a microphone stream +microphone_stream = aai.extras.MicrophoneStream() + +# Press CTRL+C to abort +transcriber.stream(microphone_stream) + +transcriber.close() +``` + +
+ +
+ Transcribe a Local Audio File in Real-Time + +```python +import assemblyai as aai + + +def on_data(transcript: aai.RealtimeTranscript): + "This function is called when a new transcript has been received." + + if not transcript.text: + return + + if isinstance(transcript, aai.RealtimeFinalTranscript): + print(transcript.text, end="\r\n") + else: + print(transcript.text, end="\r") + +def on_error(error: aai.RealtimeError): + "This function is called when the connection has been closed." + + print("An error occured:", error) + + +# Create the Real-Time transcriber +transcriber = aai.RealtimeTranscriber( + on_data=on_data, + on_error=on_error, + sample_rate=44_100, + on_open=on_open, # optional + on_close=on_close, # optional +) + + +# Only WAV/PCM16 single channel supported for now +file_stream = aai.extras.stream_file( + filepath="audio.wav", + sample_rate=44_100, +) + +transcriber.stream(file_stream) + +transcriber.close() +``` + +
+ +--- + ## Playgrounds Visit one of our Playgrounds: diff --git a/assemblyai/__init__.py b/assemblyai/__init__.py index 530d65d..ae25dda 100644 --- a/assemblyai/__init__.py +++ b/assemblyai/__init__.py @@ -1,6 +1,7 @@ +from . import extras from .client import Client from .lemur import Lemur -from .transcriber import Transcriber, Transcript, TranscriptGroup +from .transcriber import RealtimeTranscriber, Transcriber, Transcript, TranscriptGroup from .types import ( AssemblyAIError, AutohighlightResponse, @@ -24,6 +25,12 @@ PIIRedactionPolicy, PIISubstitutionPolicy, RawTranscriptionConfig, + RealtimeError, + RealtimeFinalTranscript, + RealtimePartialTranscript, + RealtimeSessionOpened, + RealtimeTranscript, + RealtimeWord, Sentence, Sentiment, SentimentType, @@ -93,6 +100,14 @@ "Word", "WordBoost", "WordSearchMatch", + "RealtimeError", + "RealtimeFinalTranscript", + "RealtimePartialTranscript", + "RealtimeSessionOpened", + "RealtimeTranscript", + "RealtimeWord", # package globals "settings", + # packages + "extras", ] diff --git a/assemblyai/extras.py b/assemblyai/extras.py new file mode 100644 index 0000000..9cf96c0 --- /dev/null +++ b/assemblyai/extras.py @@ -0,0 +1,102 @@ +import time +from typing import Generator + +try: + import pyaudio +except ImportError: + raise ImportError( + "You must install the extras for this SDK to use this feature. " + "Run `pip install assemblyai[extras]` to install the extras. " + "Make sure to install `apt install portaudio19-dev` (Debian/Ubuntu) or " + "`brew install portaudio` (MacOS) before installing the extras." + ) + + +class MicrophoneStream: + def __init__( + self, + sample_rate: int = 44_100, + ): + """ + Creates a stream of audio from the microphone. + + Args: + chunk_size: The size of each chunk of audio to read from the microphone. + channels: The number of channels to record audio from. + sample_rate: The sample rate to record audio at. + """ + + self._pyaudio = pyaudio.PyAudio() + self.sample_rate = sample_rate + + self._chunk_size = int(self.sample_rate * 0.1) + self._stream = self._pyaudio.open( + format=pyaudio.paInt16, + channels=1, + rate=sample_rate, + input=True, + frames_per_buffer=self._chunk_size, + ) + + self._open = True + + def __iter__(self): + """ + Returns the iterator object. + """ + + return self + + def __next__(self): + """ + Reads a chunk of audio from the microphone. + """ + if not self._open: + raise StopIteration + + try: + return self._stream.read(self._chunk_size) + except KeyboardInterrupt: + raise StopIteration + + def close(self): + """ + Closes the stream. + """ + + self._open = False + + if self._stream.is_active(): + self._stream.stop_stream() + + self._stream.close() + self._pyaudio.terminate() + + +def stream_file( + filepath: str, + sample_rate: int, +) -> Generator[bytes, None, None]: + """ + Mimics a stream of audio data by reading it chunk by chunk from a file. + + NOTE: Only supports WAV/PCM16 files as of now. + + Args: + filepath: The path to the file to stream. + sample_rate: The sample rate of the audio file. + + Returns: A generator that yields chunks of audio data. + """ + + with open(filepath, "rb") as f: + while True: + data = f.read(int(sample_rate * 0.30) * 2) + enough_data = ((len(data) / (16 / 8)) / sample_rate) * 1_000 + + if not data or enough_data < 300.0: + break + + yield data + + time.sleep(0.15) diff --git a/assemblyai/transcriber.py b/assemblyai/transcriber.py index 718fcdc..f06be81 100644 --- a/assemblyai/transcriber.py +++ b/assemblyai/transcriber.py @@ -1,12 +1,30 @@ from __future__ import annotations +import base64 import concurrent.futures +import json import os +import queue +import threading import time -from typing import Dict, Iterator, List, Optional, Union -from urllib.parse import urlparse - +from typing import ( + Any, + Callable, + Dict, + Generator, + Iterable, + Iterator, + List, + Optional, + Union, +) +from urllib.parse import urlencode, urlparse + +import websockets +import websockets.exceptions +from httpx import request from typing_extensions import Self +from websockets.sync.client import connect as websocket_connect from . import api from . import client as _client @@ -824,3 +842,270 @@ def transcribe_group_async( config=config, poll=True, ) + + +class _RealtimeTranscriberImpl: + def __init__( + self, + *, + on_data: Callable[[types.RealtimeTranscript], None], + on_error: Callable[[types.RealtimeError], None], + on_open: Optional[Callable[[types.RealtimeSessionOpened], None]], + on_close: Optional[Callable[[], None]], + sample_rate: int, + word_boost: List[str], + client: _client.Client, + ) -> None: + self._client = client + self._websocket: Optional[websockets_client.ClientConnection] = None + + self._on_open = on_open + self._on_data = on_data + self._on_error = on_error + self._on_close = on_close + self._sample_rate = sample_rate + self._word_boost = word_boost + + self._write_queue: queue.Queue[bytes] = queue.Queue() + self._write_thread = threading.Thread(target=self._write) + self._read_thread = threading.Thread(target=self._read) + self._stop_event = threading.Event() + + def connect( + self, + timeout: Optional[float], + ) -> None: + """ + Connects to the real-time service. + + Args: + `timeout`: The maximum time to wait for the connection to be established. + """ + + params: Dict[str, Any] = { + "sample_rate": self._sample_rate, + } + if self._word_boost: + params["word_boost"] = self._word_boost + + websocket_base_url = self._client.settings.base_url.replace("https", "wss") + + try: + self._websocket = websocket_connect( + f"{websocket_base_url}/realtime/ws?{urlencode(params)}", + additional_headers={ + "Authorization": f"{self._client.settings.api_key}" + }, + open_timeout=timeout, + ) + except Exception as exc: + return self._on_error( + types.RealtimeError( + f"Could not connect to the real-time service: {exc}" + ) + ) + + self._read_thread.start() + self._write_thread.start() + + def stream(self, data: bytes) -> None: + """ + Streams audio data to the real-time service by putting it into a queue. + """ + + self._write_queue.put(data) + + def close(self, terminate: bool = False) -> None: + """ + Closes the connection to the real-time service gracefully. + """ + + with self._write_queue.mutex: + self._write_queue.queue.clear() + + if terminate and not self._stop_event.is_set(): + self._websocket.send(json.dumps({"terminate_session": True})) + self._websocket.close() + + self._stop_event.set() + + try: + self._read_thread.join() + self._write_thread.join() + except Exception: + pass + + if self._on_close: + self._on_close() + + def _read(self) -> None: + """ + Reads messages from the real-time service. + + Must run in a separate thread to avoid blocking the main thread. + """ + + while not self._stop_event.is_set(): + try: + message = self._websocket.recv(timeout=1) + except TimeoutError: + continue + except websockets.exceptions.ConnectionClosed as exc: + return self._handle_error(exc) + + try: + message = json.loads(message) + except json.JSONDecodeError as exc: + self._on_error( + types.RealtimeError( + f"Could not decode message: {exc}", + ) + ) + continue + + self._handle_message(message) + + def _write(self) -> None: + """ + Writes messages to the real-time service. + + Must run in a separate thread to avoid blocking the main thread. + """ + + while not self._stop_event.is_set(): + try: + data = self._write_queue.get(timeout=1) + except queue.Empty: + continue + + try: + self._websocket.send(self._encode_data(data)) + except websockets.exceptions.ConnectionClosed as exc: + return self._handle_error(exc) + + def _encode_data(self, data: bytes) -> str: + """ + Encodes the given audio chunk as a base64 string. + + This is a helper method for `_write`. + """ + + return json.dumps( + { + "audio_data": base64.b64encode(data).decode("utf-8"), + } + ) + + def _handle_message( + self, + message: Dict[str, Any], + ) -> None: + """ + Handles a message received from the real-time service by calling the appropriate + callback. + + Args: + `message`: The message to handle. + """ + if "message_type" in message: + if message["message_type"] == types.RealtimeMessageTypes.partial_transcript: + self._on_data(types.RealtimePartialTranscript(**message)) + elif message["message_type"] == types.RealtimeMessageTypes.final_transcript: + self._on_data(types.RealtimeFinalTranscript(**message)) + elif ( + message["message_type"] == types.RealtimeMessageTypes.session_begins + and self._on_open + ): + self._on_open(types.RealtimeSessionOpened(**message)) + elif "error" in message: + self._on_error(types.RealtimeError(message["error"])) + + def _handle_error(self, error: websockets.exceptions.ConnectionClosed) -> None: + """ + Handles a WebSocket error by calling the appropriate callback. + """ + if error.code >= 4000 and error.code <= 4999: + error_message = types.RealtimeErrorMapping[error.code] + else: + error_message = error.reason + + self._on_error(types.RealtimeError(error_message)) + self.close() + + +class RealtimeTranscriber: + def __init__( + self, + *, + on_data: Callable[[types.RealtimeTranscript], None], + on_error: Callable[[types.RealtimeError], None], + on_open: Optional[Callable[[types.RealtimeSessionOpened], None]] = None, + on_close: Optional[Callable[[], None]] = None, + sample_rate: int, + word_boost: List[str] = [], + client: Optional[_client.Client] = None, + ) -> None: + """ + Creates a new real-time transcriber. + + Args: + `on_data`: The callback to call when a new transcript is received. + `on_error`: The callback to call when an error occurs. + `on_open`: (Optional) The callback to call when the connection to the real-time service + `on_close`: (Optional) The callback to call when the connection to the real-time service + `sample_rate`: The sample rate of the audio data. + `word_boost`: (Optional) A list of words to boost the confidence of. + `client`: (Optional) The client to use for the real-time service. + """ + + self._client = client or _client.Client.get_default() + + self._impl = _RealtimeTranscriberImpl( + on_open=on_open, + on_data=on_data, + on_error=on_error, + on_close=on_close, + sample_rate=sample_rate, + word_boost=word_boost, + client=self._client, + ) + + def connect( + self, + timeout: Optional[float] = 10.0, + ) -> None: + """ + Connects to the real-time service. + + Args: + `timeout`: The timeout in seconds to wait for the connection to be established. + A `timeout` of `None` means no timeout. + """ + + self._impl.connect(timeout=timeout) + + def stream( + self, + data: Union[bytes, Generator[bytes, None, None], Iterable[bytes]], + ) -> None: + """ + Streams raw audio data to the real-time service. + + Args: + `data`: Raw audio data in `bytes` or a generator/iterable of `bytes`. + + Note: Make sure that `data` matches the `sample_rate` that was given in the constructor. + """ + if isinstance(data, bytes): + self._impl.stream(data) + return + + for chunk in data: + self._impl.stream(chunk) + + def close(self) -> None: + """ + Closes the connection to the real-time service. + """ + + self._impl.close(terminate=True) diff --git a/assemblyai/types.py b/assemblyai/types.py index fcb2692..bfce3fb 100644 --- a/assemblyai/types.py +++ b/assemblyai/types.py @@ -1,7 +1,8 @@ +from datetime import datetime from enum import Enum -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union -from pydantic import BaseModel, BaseSettings, Extra, Field +from pydantic import UUID4, BaseModel, BaseSettings, Extra, Field from typing_extensions import Self @@ -1532,3 +1533,136 @@ class LemurCoachRequest(BaseModel): class LemurCoachResponse(BaseModel): response: str model: LemurModel = LemurModel.default + + +class RealtimeMessageTypes(str, Enum): + """ + The type of message received from the real-time API + """ + + partial_transcript = "PartialTranscript" + final_transcript = "FinalTranscript" + session_begins = "SessionBegins" + + +class RealtimeSessionOpened(BaseModel): + """ + Once a real-time session is opened, the client will receive this message + """ + + message_type: Literal[ + RealtimeMessageTypes.session_begins + ] = RealtimeMessageTypes.session_begins + + session_id: UUID4 + "Unique identifier for the established session." + + expires_at: datetime + "Timestamp when this session will expire." + + +class RealtimeWord(BaseModel): + """ + A word in a real-time transcript + """ + + start: int + "Start time of word relative to session start, in milliseconds" + + end: int + "End time of word relative to session start, in milliseconds" + + confidence: float + "The confidence score of the word, between 0 and 1" + + text: str + "The word itself" + + +class RealtimeTranscript(BaseModel): + """ + Base class for real-time transcript messages. + """ + + message_type: Literal[ + RealtimeMessageTypes.partial_transcript, RealtimeMessageTypes.final_transcript + ] + "Describes the type of message" + + audio_start: int + "Start time of audio sample relative to session start, in milliseconds" + + audio_end: int + "End time of audio sample relative to session start, in milliseconds" + + confidence: float + "The confidence score of the entire transcription, between 0 and 1" + + text: str + "The transcript for your audio" + + words: List[Word] + """ + An array of objects, with the information for each word in the transcription text. + Will include the `start`/`end` time (in milliseconds) of the word, the `confidence` score of the word, + and the `text` (i.e. the word itself) + """ + + created: datetime + "Timestamp when this message was created" + + +class RealtimePartialTranscript(RealtimeTranscript): + """ + As you send audio data to the service, the service will immediately start responding with partial transcripts. + """ + + message_type: Literal[ + RealtimeMessageTypes.partial_transcript + ] = RealtimeMessageTypes.partial_transcript + + +class RealtimeFinalTranscript(RealtimeTranscript): + """ + After you've received your partial results, our model will continue to analyze incoming audio and, + when it detects the end of an "utterance" (usually a pause in speech), it will finalize the results + sent to you so far with higher accuracy, as well as add punctuation and casing to the transcription text. + """ + + message_type: Literal[ + RealtimeMessageTypes.final_transcript + ] = RealtimeMessageTypes.final_transcript + + punctuated: bool + "Whether the transcript has been punctuated and cased" + + text_formatted: bool + "Whether the transcript has been formatted (e.g. Dollar -> $)" + + +class RealtimeError(AssemblyAIError): + """ + Real-time error message + """ + + +RealtimeErrorMapping = { + 4000: "Sample rate must be a positive integer", + 4001: "Not Authorized", + 4002: "Insufficient Funds", + 4003: """This feature is paid-only and requires you to add a credit card. + Please visit https://app.assemblyai.com/ to add a credit card to your account""", + 4004: "Session Not Found", + 4008: "Session Expired", + 4010: "Session Previously Closed", + 4029: "Client sent audio too fast", + 4030: "Session is handled by another websocket", + 4031: "Session idle for too long", + 4032: "Audio duration is too short", + 4033: "Audio duration is too long", + 4100: "Endpoint received invalid JSON", + 4101: "Endpoint received a message with an invalid schema", + 4102: "This account has exceeded the number of allowed streams", + 4103: "The session has been reconnected. This websocket is no longer valid.", + 1013: "Temporary server condition forced blocking client's request", +} diff --git a/setup.py b/setup.py index 36e6cf7..225205b 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ setup( name="assemblyai", - version="0.9.0", + version="0.10.0", description="AssemblyAI Python SDK", author="AssemblyAI", author_email="engineering.sdk@assemblyai.com", @@ -15,8 +15,12 @@ install_requires=[ "httpx>=0.19.0", "pydantic>=1.7.0", - "typing-extensions", + "typing-extensions>=3.7,<4.6", + "websockets>=11.0", ], + extras_require={ + "extras": ["pyaudio>=0.2.13"], + }, classifiers=[ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", diff --git a/tests/unit/test_realtime_transcriber.py b/tests/unit/test_realtime_transcriber.py new file mode 100644 index 0000000..cb6f8a4 --- /dev/null +++ b/tests/unit/test_realtime_transcriber.py @@ -0,0 +1,490 @@ +import datetime +import json +import uuid +from typing import Optional +from unittest.mock import MagicMock +from urllib.parse import urlencode + +import pytest +import websockets.exceptions +from faker import Faker +from pytest_mock import MockFixture + +import assemblyai as aai + +aai.settings.api_key = "test" +aai.settings.base_url = "https://api.assemblyai.com/v2" + + +def _disable_rw_threads(mocker: MockFixture): + """ + Disable the read/write threads for the websocket + """ + + mocker.patch("threading.Thread.start", return_value=None) + + +def test_realtime_connect_has_parameters(mocker: MockFixture): + """ + Test that the connect method has the correct parameters set + """ + actual_url = None + actual_additional_headers = None + actual_open_timeout = None + + def mocked_websocket_connect( + url: str, additional_headers: dict, open_timeout: float + ): + nonlocal actual_url, actual_additional_headers, actual_open_timeout + actual_url = url + actual_additional_headers = additional_headers + actual_open_timeout = open_timeout + + mocker.patch( + "assemblyai.transcriber.websocket_connect", + new=mocked_websocket_connect, + ) + _disable_rw_threads(mocker) + + transcriber = aai.RealtimeTranscriber( + on_data=lambda: None, + on_error=lambda error: print(error), + sample_rate=44_100, + word_boost=["AssemblyAI"], + ) + + transcriber.connect(timeout=15.0) + + assert ( + actual_url + == f"wss://api.assemblyai.com/v2/realtime/ws?{urlencode(dict(sample_rate=44100, word_boost=['AssemblyAI']))}" + ) + assert actual_additional_headers == {"Authorization": aai.settings.api_key} + assert actual_open_timeout == 15.0 + + +def test_realtime_connect_succeeds(mocker: MockFixture): + """ + Tests that the `RealtimeTranscriber` successfully connects to the `real-time` service. + """ + on_error_called = False + + def on_error(error: aai.RealtimeError): + nonlocal on_error_called + on_error_called = True + + transcriber = aai.RealtimeTranscriber( + on_data=lambda _: None, + on_error=on_error, + sample_rate=44_100, + ) + + mocker.patch( + "assemblyai.transcriber.websocket_connect", + return_value=MagicMock(), + ) + + # mock the read/write threads + _disable_rw_threads(mocker) + + # should pass + transcriber.connect() + + # no errors should be called + assert not on_error_called + + +def test_realtime_connect_fails(mocker: MockFixture): + """ + Tests that the `RealtimeTranscriber` fails to connect to the `real-time` service. + """ + + on_error_called = False + + def on_error(error: aai.RealtimeError): + nonlocal on_error_called + on_error_called = True + + assert isinstance(error, aai.RealtimeError) + assert "connection failed" in str(error) + + transcriber = aai.RealtimeTranscriber( + on_data=lambda _: None, + on_error=on_error, + sample_rate=44_100, + ) + mocker.patch( + "assemblyai.transcriber.websocket_connect", + side_effect=Exception("connection failed"), + ) + + transcriber.connect() + + assert on_error_called + + +def test_realtime__read_succeeds(mocker: MockFixture, faker: Faker): + """ + Tests the `_read` method of the `_RealtimeTranscriberImpl` class. + """ + + expected_transcripts = [ + aai.RealtimeFinalTranscript( + created=faker.date_time(), + text=faker.sentence(), + audio_start=0, + audio_end=1, + confidence=1.0, + words=[], + punctuated=True, + text_formatted=True, + ) + ] + + received_transcripts = [] + + def on_data(data: aai.RealtimeTranscript): + nonlocal received_transcripts + received_transcripts.append(data) + + transcriber = aai.RealtimeTranscriber( + on_data=on_data, + on_error=lambda _: None, + sample_rate=44_100, + ) + + transcriber._impl._websocket = MagicMock() + websocket_recv = [ + json.dumps(msg.dict(), default=str) for msg in expected_transcripts + ] + transcriber._impl._websocket.recv.side_effect = websocket_recv + + with pytest.raises(StopIteration): + transcriber._impl._read() + + assert received_transcripts == expected_transcripts + + +def test_realtime__read_fails(mocker: MockFixture): + """ + Tests the `_read` method of the `_RealtimeTranscriberImpl` class. + """ + + on_error_called = False + + def on_error(error: aai.RealtimeError): + nonlocal on_error_called + on_error_called = True + + transcriber = aai.RealtimeTranscriber( + on_data=lambda _: None, + on_error=on_error, + sample_rate=44_100, + ) + + transcriber._impl._websocket = MagicMock() + error = websockets.exceptions.ConnectionClosedOK(rcvd=None, sent=None) + transcriber._impl._websocket.recv.side_effect = error + + transcriber._impl._read() + + assert on_error_called + + +def test_realtime__write_succeeds(mocker: MockFixture): + """ + Tests the `_write` method of the `_RealtimeTranscriberImpl` class. + """ + audio_chunks = [ + bytes([1, 2, 3, 4, 5]), + bytes([6, 7, 8, 9, 10]), + ] + + actual_sent = [] + + def mocked_send(data: str): + nonlocal actual_sent + actual_sent.append(data) + + transcriber = aai.RealtimeTranscriber( + on_data=lambda _: None, + on_error=lambda _: None, + sample_rate=44_100, + ) + + transcriber._impl._websocket = MagicMock() + transcriber._impl._websocket.send = mocked_send + transcriber._impl._stop_event.is_set = MagicMock(side_effect=[False, False, True]) + + transcriber.stream(audio_chunks[0]) + transcriber.stream(audio_chunks[1]) + + transcriber._impl._write() + + # assert that the correct data was sent (base64 encoded) + assert len(actual_sent) == 2 + assert json.loads(actual_sent[0]) == {"audio_data": "AQIDBAU="} + assert json.loads(actual_sent[1]) == {"audio_data": "BgcICQo="} + + +def test_realtime__encode_data(mocker: MockFixture): + """ + Tests the `_encode_data` method of the `_RealtimeTranscriberImpl` class. + """ + + audio_chunks = [ + bytes([1, 2, 3, 4, 5]), + bytes([6, 7, 8, 9, 10]), + ] + + expected_encoded_data = [ + json.dumps({"audio_data": "AQIDBAU="}), + json.dumps({"audio_data": "BgcICQo="}), + ] + + transcriber = aai.RealtimeTranscriber( + on_data=lambda _: None, + on_error=lambda _: None, + sample_rate=44_100, + ) + + actual_encoded_data = [] + for chunk in audio_chunks: + actual_encoded_data.append(transcriber._impl._encode_data(chunk)) + + assert actual_encoded_data == expected_encoded_data + + +def test_realtime__handle_message_session_begins(mocker: MockFixture): + """ + Tests the `_handle_message` method of the `_RealtimeTranscriberImpl` class + with the `SessionBegins` message. + """ + + test_message = { + "message_type": "SessionBegins", + "session_id": str(uuid.uuid4()), + "expires_at": datetime.datetime.now().isoformat(), + } + + on_open_called = False + + def on_open(session_opened: aai.RealtimeSessionOpened): + nonlocal on_open_called + on_open_called = True + assert isinstance(session_opened, aai.RealtimeSessionOpened) + assert session_opened.session_id == uuid.UUID(test_message["session_id"]) + assert session_opened.expires_at.isoformat() == test_message["expires_at"] + + transcriber = aai.RealtimeTranscriber( + on_open=on_open, + on_data=lambda _: None, + on_error=lambda _: None, + sample_rate=44_100, + ) + + transcriber._impl._handle_message(test_message) + + assert on_open_called + + +def test_realtime__handle_message_partial_transcript(mocker: MockFixture): + """ + Tests the `_handle_message` method of the `_RealtimeTranscriberImpl` class + with the `PartialTranscript` message. + """ + + test_message = { + "message_type": "PartialTranscript", + "text": "hello world", + "audio_start": 0, + "audio_end": 1500, + "confidence": 0.99, + "created": datetime.datetime.now().isoformat(), + "words": [ + { + "text": "hello", + "start": 0, + "end": 500, + "confidence": 0.99, + }, + { + "text": "world", + "start": 500, + "end": 1500, + "confidence": 0.99, + }, + ], + } + + on_data_called = False + + def on_data(data: aai.RealtimePartialTranscript): + nonlocal on_data_called + on_data_called = True + assert isinstance(data, aai.RealtimePartialTranscript) + assert data.text == test_message["text"] + assert data.audio_start == test_message["audio_start"] + assert data.audio_end == test_message["audio_end"] + assert data.confidence == test_message["confidence"] + assert data.created.isoformat() == test_message["created"] + assert data.words == [ + aai.RealtimeWord( + text=test_message["words"][0]["text"], + start=test_message["words"][0]["start"], + end=test_message["words"][0]["end"], + confidence=test_message["words"][0]["confidence"], + ), + aai.RealtimeWord( + text=test_message["words"][1]["text"], + start=test_message["words"][1]["start"], + end=test_message["words"][1]["end"], + confidence=test_message["words"][1]["confidence"], + ), + ] + + transcriber = aai.RealtimeTranscriber( + on_data=on_data, + on_error=lambda _: None, + sample_rate=44_100, + ) + + transcriber._impl._handle_message(test_message) + + assert on_data_called + + +def test_realtime__handle_message_final_transcript(mocker: MockFixture): + """ + Tests the `_handle_message` method of the `_RealtimeTranscriberImpl` class + with the `FinalTranscript` message. + """ + + test_message = { + "message_type": "FinalTranscript", + "text": "Hello, world!", + "audio_start": 0, + "audio_end": 1500, + "confidence": 0.99, + "created": datetime.datetime.now().isoformat(), + "punctuated": True, + "text_formatted": True, + "words": [ + { + "text": "Hello,", + "start": 0, + "end": 500, + "confidence": 0.99, + }, + { + "text": "world!", + "start": 500, + "end": 1500, + "confidence": 0.99, + }, + ], + } + + on_data_called = False + + def on_data(data: aai.RealtimeFinalTranscript): + nonlocal on_data_called + on_data_called = True + assert isinstance(data, aai.RealtimeFinalTranscript) + assert data.text == test_message["text"] + assert data.audio_start == test_message["audio_start"] + assert data.audio_end == test_message["audio_end"] + assert data.confidence == test_message["confidence"] + assert data.created.isoformat() == test_message["created"] + assert data.punctuated == test_message["punctuated"] + assert data.text_formatted == test_message["text_formatted"] + assert data.words == [ + aai.RealtimeWord( + text=test_message["words"][0]["text"], + start=test_message["words"][0]["start"], + end=test_message["words"][0]["end"], + confidence=test_message["words"][0]["confidence"], + ), + aai.RealtimeWord( + text=test_message["words"][1]["text"], + start=test_message["words"][1]["start"], + end=test_message["words"][1]["end"], + confidence=test_message["words"][1]["confidence"], + ), + ] + + transcriber = aai.RealtimeTranscriber( + on_data=on_data, + on_error=lambda _: None, + sample_rate=44_100, + ) + + transcriber._impl._handle_message(test_message) + + assert on_data_called + + +def test_realtime__handle_message_error_message(mocker: MockFixture): + """ + Tests the `_handle_message` method of the `_RealtimeTranscriberImpl` class + with the error message. + """ + + test_message = { + "error": "test error", + } + + on_error_called = False + + def on_error(error: aai.RealtimeError): + nonlocal on_error_called + on_error_called = True + assert isinstance(error, aai.RealtimeError) + assert str(error) == test_message["error"] + + transcriber = aai.RealtimeTranscriber( + on_data=lambda _: None, + on_error=on_error, + sample_rate=44_100, + ) + + transcriber._impl._handle_message(test_message) + + assert on_error_called + + +def test_realtime__handle_message_unknown_message(mocker: MockFixture): + """ + Tests the `_handle_message` method of the `_RealtimeTranscriberImpl` class + with an unknown message. + """ + + test_message = { + "message_type": "Unknown", + } + + on_data_called = False + + def on_data(data: aai.RealtimeTranscript): + nonlocal on_data_called + on_data_called = True + + on_error_called = False + + def on_error(error: aai.RealtimeError): + nonlocal on_error_called + on_error_called = True + + transcriber = aai.RealtimeTranscriber( + on_data=on_data, + on_error=on_error, + sample_rate=44_100, + ) + + transcriber._impl._handle_message(test_message) + + assert not on_data_called + assert not on_error_called + + +# TODO: create tests for the `RealtimeTranscriber.close` method diff --git a/tox.ini b/tox.ini index a5327f1..b10d573 100644 --- a/tox.ini +++ b/tox.ini @@ -1,9 +1,11 @@ [tox] -envlist = py{38,39,310,311}-httpx{latest,0.24,0.23,0.22,0.21}-pydantic{latest,1.10,1.9,1.8,1.7}-typing-extensions +envlist = py{38,39,310,311}-websockets{latest,11.0}-pyaudio{latest,0.2}-httpx{latest,0.24,0.23,0.22,0.21}-pydantic{latest,1.10,1.9,1.8,1.7}-typing-extensions [testenv] deps = # library dependencies + websocketslatest: websockets + websockets11.0: websockets>=11.0.0,<12.0.0 httpxlatest: httpx httpx0.24: httpx>=0.24.0,<0.25.0 httpx0.23: httpx>=0.23.0,<0.24.0 @@ -14,11 +16,15 @@ deps = pydantic1.9: pydantic>=1.9.0,<1.10.0 pydantic1.8: pydantic>=1.8.0,<1.9.0 pydantic1.7: pydantic>=1.7.0,<1.8.0 - typing-extensions: typing-extensions>=3.7 + typing-extensions: typing-extensions>=3.7,<4.6 + # extra dependencies + pyaudiolatest: pyaudio + pyaudio0.2: pyaudio>=0.2.13,<0.3.0 # test dependencies pytest pytest-httpx pytest-xdist + pytest-mock pytest-cov factory-boy allowlist_externals = pytest