From 130c83a92a760972384d933b5ddf218b604acfff Mon Sep 17 00:00:00 2001 From: AssemblyAI Date: Mon, 2 Jun 2025 12:38:00 +0200 Subject: [PATCH] Project import generated by Copybara. GitOrigin-RevId: c70193a0e8ed6c0bcbd46c4eb0b6e8f21973e49e --- assemblyai/__version__.py | 2 +- assemblyai/streaming/v3/client.py | 68 +++++++++++++++++++++++-------- assemblyai/streaming/v3/models.py | 7 ++-- tests/unit/test_streaming.py | 40 +++++++++++++++++- 4 files changed, 93 insertions(+), 24 deletions(-) diff --git a/assemblyai/__version__.py b/assemblyai/__version__.py index 2ccdd8e..22ffde2 100644 --- a/assemblyai/__version__.py +++ b/assemblyai/__version__.py @@ -1 +1 @@ -__version__ = "0.41.0b4" +__version__ = "0.41.0" diff --git a/assemblyai/streaming/v3/client.py b/assemblyai/streaming/v3/client.py index e047ee1..c9b178a 100644 --- a/assemblyai/streaming/v3/client.py +++ b/assemblyai/streaming/v3/client.py @@ -33,14 +33,6 @@ logger = logging.getLogger(__name__) -def _user_agent() -> str: - vi = sys.version_info - python_version = f"{vi.major}.{vi.minor}.{vi.micro}" - return ( - f"AssemblyAI/1.0 (sdk=Python/{__version__} runtime_env=Python/{python_version})" - ) - - def _dump_model(model: BaseModel): if hasattr(model, "model_dump"): return model.model_dump(exclude_none=True) @@ -53,10 +45,20 @@ def _dump_model_json(model: BaseModel): return model.json(exclude_none=True) +def _user_agent() -> str: + vi = sys.version_info + python_version = f"{vi.major}.{vi.minor}.{vi.micro}" + return ( + f"AssemblyAI/1.0 (sdk=Python/{__version__} runtime_env=Python/{python_version})" + ) + + class StreamingClient: def __init__(self, options: StreamingClientOptions): self._options = options + self._client = _HTTPClient(api_host=options.api_host, api_key=options.api_key) + self._handlers: Dict[StreamingEvents, List[Callable]] = {} for event in StreamingEvents.__members__.values(): @@ -73,7 +75,9 @@ def connect(self, params: StreamingParameters) -> None: uri = f"wss://{self._options.api_host}/v3/ws?{params_encoded}" headers = { - "Authorization": self._options.api_key, + "Authorization": self._options.token + if self._options.token + else self._options.api_key, "User-Agent": _user_agent(), "AssemblyAI-Version": "2025-05-12", } @@ -253,14 +257,44 @@ def _parse_error( message=f"Unknown error: {error}", ) + def create_temporary_token( + self, + expires_in_seconds: int, + max_session_duration_seconds: int, + ) -> str: + return self._client.create_temporary_token( + expires_in_seconds=expires_in_seconds, + max_session_duration_seconds=max_session_duration_seconds, + ) -class HTTPClient: - def __init__(self, options: StreamingClientOptions): - headers = { - "Authorization": options.api_key, - "User-Agent": _user_agent(), - } - base_url = f"https://{options.api_host}" +class _HTTPClient: + def __init__(self, api_host: str, api_key: Optional[str] = None): + vi = sys.version_info + python_version = f"{vi.major}.{vi.minor}.{vi.micro}" + user_agent = f"{httpx._client.USER_AGENT} AssemblyAI/1.0 (sdk=Python/{__version__} runtime_env=Python/{python_version})" + + headers = {"User-Agent": user_agent} - self._http_client = httpx.Client(base_url=base_url, headers=headers, timeout=30) + if api_key: + headers["Authorization"] = api_key + + self._http_client = httpx.Client( + base_url="https://" + api_host, + headers=headers, + ) + + def create_temporary_token( + self, + expires_in_seconds: int, + max_session_duration_seconds: int, + ) -> str: + response = self._http_client.get( + "/v3/token", + params={ + "expires_in": expires_in_seconds, + "max_session_duration": max_session_duration_seconds, + }, + ) + response.raise_for_status() + return response.json()["token"] diff --git a/assemblyai/streaming/v3/models.py b/assemblyai/streaming/v3/models.py index 8a16090..86d2c4d 100644 --- a/assemblyai/streaming/v3/models.py +++ b/assemblyai/streaming/v3/models.py @@ -56,16 +56,14 @@ class ForceEndpoint(BaseModel): class StreamingSessionParameters(BaseModel): - word_finalization_max_wait_time: Optional[int] = None end_of_turn_confidence_threshold: Optional[float] = None min_end_of_turn_silence_when_confident: Optional[int] = None max_turn_silence: Optional[int] = None - formatted_finals: Optional[bool] = None + format_turns: Optional[bool] = None class StreamingParameters(StreamingSessionParameters): sample_rate: int - token: Optional[str] = None class UpdateConfiguration(StreamingSessionParameters): @@ -81,8 +79,9 @@ class UpdateConfiguration(StreamingSessionParameters): class StreamingClientOptions(BaseModel): + api_host: str = "streaming.assemblyai.com" api_key: Optional[str] = None - api_host: str + token: Optional[str] = None class StreamingError(Exception): diff --git a/tests/unit/test_streaming.py b/tests/unit/test_streaming.py index 9409856..af04db9 100644 --- a/tests/unit/test_streaming.py +++ b/tests/unit/test_streaming.py @@ -55,6 +55,44 @@ def mocked_websocket_connect( assert actual_open_timeout == 15 +def test_client_connect_with_token(mocker: MockFixture): + actual_url = None + actual_additional_headers = None + actual_open_timeout = None + + def mocked_websocket_connect( + url: str, additional_headers: dict, open_timeout: float + ): + nonlocal actual_url, actual_additional_headers, actual_open_timeout + actual_url = url + actual_additional_headers = additional_headers + actual_open_timeout = open_timeout + + mocker.patch( + "assemblyai.streaming.v3.client.websocket_connect", + new=mocked_websocket_connect, + ) + + _disable_rw_threads(mocker) + + options = StreamingClientOptions(token="test", api_host="api.example.com") + client = StreamingClient(options) + + params = StreamingParameters(sample_rate=16000) + client.connect(params) + + expected_headers = { + "sample_rate": params.sample_rate, + } + + assert actual_url == f"wss://api.example.com/v3/ws?{urlencode(expected_headers)}" + assert actual_additional_headers["Authorization"] == "test" + assert actual_additional_headers["AssemblyAI-Version"] == "2025-05-12" + assert "AssemblyAI/1.0" in actual_additional_headers["User-Agent"] + + assert actual_open_timeout == 15 + + def test_client_connect_all_parameters(mocker: MockFixture): actual_url = None actual_additional_headers = None @@ -80,7 +118,6 @@ def mocked_websocket_connect( params = StreamingParameters( sample_rate=16000, - word_finalization_max_wait_time=5000, end_of_turn_confidence_threshold=0.5, min_end_of_turn_silence_when_confident=2000, max_turn_silence=3000, @@ -89,7 +126,6 @@ def mocked_websocket_connect( client.connect(params) expected_headers = { - "word_finalization_max_wait_time": params.word_finalization_max_wait_time, "end_of_turn_confidence_threshold": params.end_of_turn_confidence_threshold, "min_end_of_turn_silence_when_confident": params.min_end_of_turn_silence_when_confident, "max_turn_silence": params.max_turn_silence,