
Commit 7a41020

ploeber and aleks-mitov authored
chore: sync code base with OSS repository (#53)
Co-authored-by: Aleks Mitov <[email protected]>
1 parent 1642920 commit 7a41020

File tree: 4 files changed, +80 −52 lines changed


assemblyai/api.py

Lines changed: 2 additions & 2 deletions
@@ -27,7 +27,7 @@ def _get_error_message(response: httpx.Response) -> str:
     try:
         return response.json()["error"]
     except Exception:
-        return response.text
+        return f"\nReason: {response.text}\nRequest: {response.request}"
 
 
 def create_transcript(
@@ -43,7 +43,7 @@ def create_transcript(
     )
     if response.status_code != httpx.codes.ok:
         raise types.TranscriptError(
-            f"failed to transcript url {request.audio_url}: {_get_error_message(response)}"
+            f"failed to transcribe url {request.audio_url}: {_get_error_message(response)}"
         )
 
     return types.TranscriptResponse.parse_obj(response.json())
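
The richer error string can be seen by exercising the helper against a failing response. A minimal sketch, re-implementing the changed `_get_error_message` standalone with locally constructed `httpx` objects (the URL here is only illustrative):

import httpx

def _get_error_message(response: httpx.Response) -> str:
    # Mirrors the updated helper: fall back to the raw body plus the originating request.
    try:
        return response.json()["error"]
    except Exception:
        return f"\nReason: {response.text}\nRequest: {response.request}"

request = httpx.Request("POST", "https://api.assemblyai.com/v2/transcript")
response = httpx.Response(500, text="Internal Server Error", request=request)

print(_get_error_message(response))
# Prints (note the leading newline from the format string):
#
# Reason: Internal Server Error
# Request: <Request('POST', 'https://api.assemblyai.com/v2/transcript')>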

assemblyai/transcriber.py

Lines changed: 74 additions & 18 deletions
@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import base64
 import concurrent.futures
 import functools
 import json
@@ -987,6 +986,7 @@ def __init__(
         encoding: Optional[types.AudioEncoding] = None,
         token: Optional[str] = None,
         client: _client.Client,
+        end_utterance_silence_threshold: Optional[int],
     ) -> None:
         self._client = client
         self._websocket: Optional[websockets.sync.client.ClientConnection] = None
@@ -999,8 +999,9 @@ def __init__(
         self._word_boost = word_boost
         self._encoding = encoding
         self._token = token
+        self._end_utterance_silence_threshold = end_utterance_silence_threshold
 
-        self._write_queue: queue.Queue[bytes] = queue.Queue()
+        self._write_queue: queue.Queue[Union[bytes, Dict]] = queue.Queue()
         self._write_thread = threading.Thread(target=self._write)
         self._read_thread = threading.Thread(target=self._read)
         self._stop_event = threading.Event()
@@ -1048,13 +1049,40 @@ def connect(
         self._read_thread.start()
         self._write_thread.start()
 
+        if self._end_utterance_silence_threshold is not None:
+            self.configure_end_utterance_silence_threshold(
+                self._end_utterance_silence_threshold
+            )
+
     def stream(self, data: bytes) -> None:
         """
         Streams audio data to the real-time service by putting it into a queue.
         """
 
         self._write_queue.put(data)
 
+    def configure_end_utterance_silence_threshold(
+        self, threshold_milliseconds: int
+    ) -> None:
+        """
+        Configures the end of utterance silence threshold.
+        Can be called multiple times during a session at any point after the session starts.
+
+        Args:
+            `threshold_milliseconds`: The threshold in milliseconds.
+        """
+
+        self._write_queue.put(
+            _RealtimeEndUtteranceSilenceThreshold(threshold_milliseconds).as_dict()
+        )
+
+    def force_end_utterance(self) -> None:
+        """
+        Forces the end of the current utterance.
+        """
+
+        self._write_queue.put(_RealtimeForceEndUtterance().as_dict())
+
     def close(self, terminate: bool = False) -> None:
         """
         Closes the connection to the real-time service gracefully.
@@ -1116,25 +1144,12 @@ def _write(self) -> None:
                 if isinstance(data, dict):
                     self._websocket.send(json.dumps(data))
                 elif isinstance(data, bytes):
-                    self._websocket.send(self._encode_data(data))
+                    self._websocket.send(data)
                 else:
                     raise ValueError("unsupported message type")
             except websockets.exceptions.ConnectionClosed as exc:
                 return self._handle_error(exc)
 
-    def _encode_data(self, data: bytes) -> str:
-        """
-        Encodes the given audio chunk as a base64 string.
-
-        This is a helper method for `_write`.
-        """
-
-        return json.dumps(
-            {
-                "audio_data": base64.b64encode(data).decode("utf-8"),
-            }
-        )
-
     def _handle_message(
         self,
         message: Dict[str, Any],
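
For context, the removed `_encode_data` helper wrapped every audio chunk in a base64 `audio_data` JSON message; after this change, raw PCM bytes go to the websocket unchanged, and only control messages (plain dicts, such as the end-of-utterance settings added below) are JSON-encoded. A standalone sketch of the before/after payloads, with no SDK objects involved:

import base64
import json

chunk = bytes([1, 2, 3, 4, 5])

# Old wire format: each audio chunk was base64-encoded and wrapped in JSON.
old_payload = json.dumps({"audio_data": base64.b64encode(chunk).decode("utf-8")})
print(old_payload)  # {"audio_data": "AQIDBAU="}

# New wire format: the audio chunk itself is sent as-is ...
new_audio_payload = chunk

# ... and only control messages are serialized to JSON before sending
# (the 700 ms value is just an illustrative threshold).
new_control_payload = json.dumps({"end_utterance_silence_threshold": 700})
print(new_control_payload)  # {"end_utterance_silence_threshold": 700}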
@@ -1208,6 +1223,25 @@ def create_temporary_token(
         )
 
 
+class _RealtimeForceEndUtterance:
+    def as_dict(self) -> Dict[str, bool]:
+        return {
+            "force_end_utterance": True,
+        }
+
+
+class _RealtimeEndUtteranceSilenceThreshold:
+    def __init__(self, threshold_milliseconds: int) -> None:
+        self._value = threshold_milliseconds
+
+    @property
+    def value(self) -> int:
+        return self._value
+
+    def as_dict(self) -> Dict[str, int]:
+        return {"end_utterance_silence_threshold": self._value}
+
+
 class RealtimeTranscriber:
     def __init__(
         self,
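
These small message classes exist only to produce the dict payloads that the write thread serializes with `json.dumps`. A quick sketch of what the new public methods enqueue, checked against the internal write queue; note that inspecting `_write_queue` is for illustration only (it is a private attribute), and the API key is a placeholder since nothing is sent here:

import assemblyai as aai

aai.settings.api_key = "dummy-key-for-local-testing"  # placeholder; no request is made

transcriber = aai.RealtimeTranscriber(
    on_data=lambda _: None,
    on_error=lambda _: None,
    sample_rate=44_100,
)

transcriber.configure_end_utterance_silence_threshold(700)
transcriber.force_end_utterance()

# Both calls enqueue plain dicts; the write thread json.dumps() them onto the websocket.
assert transcriber._impl._write_queue.get_nowait() == {"end_utterance_silence_threshold": 700}
assert transcriber._impl._write_queue.get_nowait() == {"force_end_utterance": True}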
@@ -1221,6 +1255,7 @@ def __init__(
         encoding: Optional[types.AudioEncoding] = None,
         token: Optional[str] = None,
         client: Optional[_client.Client] = None,
+        end_utterance_silence_threshold: Optional[int] = None,
     ) -> None:
         """
         Creates a new real-time transcriber.
@@ -1235,6 +1270,7 @@ def __init__(
             `encoding`: (Optional) The encoding of the audio data.
             `token`: (Optional) A temporary authentication token.
             `client`: (Optional) The client to use for the real-time service.
+            `end_utterance_silence_threshold`: (Optional) The end utterance silence threshold in milliseconds.
         """
 
         self._client = client or _client.Client.get_default(
@@ -1251,6 +1287,7 @@ def __init__(
             encoding=encoding,
             token=token,
             client=self._client,
+            end_utterance_silence_threshold=end_utterance_silence_threshold,
         )
 
     def connect(
@@ -1268,8 +1305,7 @@ def connect(
         self._impl.connect(timeout=timeout)
 
     def stream(
-        self,
-        data: Union[bytes, Generator[bytes, None, None], Iterable[bytes]],
+        self, data: Union[bytes, Generator[bytes, None, None], Iterable[bytes]]
     ) -> None:
         """
         Streams raw audio data to the real-time service.
@@ -1286,6 +1322,26 @@ def stream(
         for chunk in data:
             self._impl.stream(chunk)
 
+    def configure_end_utterance_silence_threshold(
+        self, threshold_milliseconds: int
+    ) -> None:
+        """
+        Configures the silence duration threshold used to detect the end of an utterance.
+        In practice, it's used to tune how the transcriptions are split into final transcripts.
+        Can be called multiple times during a session at any point after the session starts.
+
+        Args:
+            `threshold_milliseconds`: The threshold in milliseconds.
+        """
+        self._impl.configure_end_utterance_silence_threshold(threshold_milliseconds)
+
+    def force_end_utterance(self) -> None:
+        """
+        Forces the end of the current utterance.
+        After calling this method, the server will end the current utterance and return a final transcript.
+        """
+        self._impl.force_end_utterance()
+
     def close(self) -> None:
         """
         Closes the connection to the real-time service.
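
Putting the new surface together at the `RealtimeTranscriber` level, a hedged end-to-end sketch: the API key, the 500/700 ms thresholds, and the `audio_chunks()` generator are placeholders, and any iterable of raw PCM byte chunks at the configured sample rate would do in place of the latter.

import assemblyai as aai

aai.settings.api_key = "<YOUR_API_KEY>"  # placeholder

def audio_chunks():
    # Placeholder audio source: yield raw PCM bytes from a microphone, file, etc.
    yield from ()

transcriber = aai.RealtimeTranscriber(
    on_data=lambda transcript: print(transcript),
    on_error=lambda error: print(error),
    sample_rate=44_100,
    # New in 0.22.0: an end-of-utterance threshold applied right after connecting.
    end_utterance_silence_threshold=500,
)

transcriber.connect()
transcriber.stream(audio_chunks())

# The threshold can be re-tuned at any point during the session ...
transcriber.configure_end_utterance_silence_threshold(700)

# ... and the current utterance can be force-ended to get a final transcript immediately.
transcriber.force_end_utterance()

transcriber.close()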

setup.py

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@
 
 setup(
     name="assemblyai",
-    version="0.21.0",
+    version="0.22.0",
     description="AssemblyAI Python SDK",
     author="AssemblyAI",
     author_email="[email protected]",

tests/unit/test_realtime_transcriber.py

Lines changed: 3 additions & 31 deletions
@@ -274,38 +274,10 @@ def mocked_send(data: str):
 
     transcriber._impl._write()
 
-    # assert that the correct data was sent (base64 encoded)
+    # assert that the correct data was sent (= the exact input bytes)
     assert len(actual_sent) == 2
-    assert json.loads(actual_sent[0]) == {"audio_data": "AQIDBAU="}
-    assert json.loads(actual_sent[1]) == {"audio_data": "BgcICQo="}
-
-
-def test_realtime__encode_data(mocker: MockFixture):
-    """
-    Tests the `_encode_data` method of the `_RealtimeTranscriberImpl` class.
-    """
-
-    audio_chunks = [
-        bytes([1, 2, 3, 4, 5]),
-        bytes([6, 7, 8, 9, 10]),
-    ]
-
-    expected_encoded_data = [
-        json.dumps({"audio_data": "AQIDBAU="}),
-        json.dumps({"audio_data": "BgcICQo="}),
-    ]
-
-    transcriber = aai.RealtimeTranscriber(
-        on_data=lambda _: None,
-        on_error=lambda _: None,
-        sample_rate=44_100,
-    )
-
-    actual_encoded_data = []
-    for chunk in audio_chunks:
-        actual_encoded_data.append(transcriber._impl._encode_data(chunk))
-
-    assert actual_encoded_data == expected_encoded_data
+    assert actual_sent[0] == audio_chunks[0]
+    assert actual_sent[1] == audio_chunks[1]
 
 
 def test_realtime__handle_message_session_begins(mocker: MockFixture):
