Skip to content

chore: sync sdk code with DeepLearning repo #121

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion assemblyai/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.41.0b4"
__version__ = "0.41.0"
68 changes: 51 additions & 17 deletions assemblyai/streaming/v3/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,6 @@
logger = logging.getLogger(__name__)


def _user_agent() -> str:
vi = sys.version_info
python_version = f"{vi.major}.{vi.minor}.{vi.micro}"
return (
f"AssemblyAI/1.0 (sdk=Python/{__version__} runtime_env=Python/{python_version})"
)


def _dump_model(model: BaseModel):
if hasattr(model, "model_dump"):
return model.model_dump(exclude_none=True)
Expand All @@ -53,10 +45,20 @@ def _dump_model_json(model: BaseModel):
return model.json(exclude_none=True)


def _user_agent() -> str:
vi = sys.version_info
python_version = f"{vi.major}.{vi.minor}.{vi.micro}"
return (
f"AssemblyAI/1.0 (sdk=Python/{__version__} runtime_env=Python/{python_version})"
)


class StreamingClient:
def __init__(self, options: StreamingClientOptions):
self._options = options

self._client = _HTTPClient(api_host=options.api_host, api_key=options.api_key)

self._handlers: Dict[StreamingEvents, List[Callable]] = {}

for event in StreamingEvents.__members__.values():
Expand All @@ -73,7 +75,9 @@ def connect(self, params: StreamingParameters) -> None:

uri = f"wss://{self._options.api_host}/v3/ws?{params_encoded}"
headers = {
"Authorization": self._options.api_key,
"Authorization": self._options.token
if self._options.token
else self._options.api_key,
"User-Agent": _user_agent(),
"AssemblyAI-Version": "2025-05-12",
}
Expand Down Expand Up @@ -253,14 +257,44 @@ def _parse_error(
message=f"Unknown error: {error}",
)

def create_temporary_token(
self,
expires_in_seconds: int,
max_session_duration_seconds: int,
) -> str:
return self._client.create_temporary_token(
expires_in_seconds=expires_in_seconds,
max_session_duration_seconds=max_session_duration_seconds,
)

class HTTPClient:
def __init__(self, options: StreamingClientOptions):
headers = {
"Authorization": options.api_key,
"User-Agent": _user_agent(),
}

base_url = f"https://{options.api_host}"
class _HTTPClient:
def __init__(self, api_host: str, api_key: Optional[str] = None):
vi = sys.version_info
python_version = f"{vi.major}.{vi.minor}.{vi.micro}"
user_agent = f"{httpx._client.USER_AGENT} AssemblyAI/1.0 (sdk=Python/{__version__} runtime_env=Python/{python_version})"

headers = {"User-Agent": user_agent}

self._http_client = httpx.Client(base_url=base_url, headers=headers, timeout=30)
if api_key:
headers["Authorization"] = api_key

self._http_client = httpx.Client(
base_url="https://" + api_host,
headers=headers,
)

def create_temporary_token(
self,
expires_in_seconds: int,
max_session_duration_seconds: int,
) -> str:
response = self._http_client.get(
"/v3/token",
params={
"expires_in": expires_in_seconds,
"max_session_duration": max_session_duration_seconds,
},
)
response.raise_for_status()
return response.json()["token"]
7 changes: 3 additions & 4 deletions assemblyai/streaming/v3/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,14 @@ class ForceEndpoint(BaseModel):


class StreamingSessionParameters(BaseModel):
word_finalization_max_wait_time: Optional[int] = None
end_of_turn_confidence_threshold: Optional[float] = None
min_end_of_turn_silence_when_confident: Optional[int] = None
max_turn_silence: Optional[int] = None
formatted_finals: Optional[bool] = None
format_turns: Optional[bool] = None


class StreamingParameters(StreamingSessionParameters):
sample_rate: int
token: Optional[str] = None


class UpdateConfiguration(StreamingSessionParameters):
Expand All @@ -81,8 +79,9 @@ class UpdateConfiguration(StreamingSessionParameters):


class StreamingClientOptions(BaseModel):
api_host: str = "streaming.assemblyai.com"
api_key: Optional[str] = None
api_host: str
token: Optional[str] = None


class StreamingError(Exception):
Expand Down
40 changes: 38 additions & 2 deletions tests/unit/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,44 @@ def mocked_websocket_connect(
assert actual_open_timeout == 15


def test_client_connect_with_token(mocker: MockFixture):
actual_url = None
actual_additional_headers = None
actual_open_timeout = None

def mocked_websocket_connect(
url: str, additional_headers: dict, open_timeout: float
):
nonlocal actual_url, actual_additional_headers, actual_open_timeout
actual_url = url
actual_additional_headers = additional_headers
actual_open_timeout = open_timeout

mocker.patch(
"assemblyai.streaming.v3.client.websocket_connect",
new=mocked_websocket_connect,
)

_disable_rw_threads(mocker)

options = StreamingClientOptions(token="test", api_host="api.example.com")
client = StreamingClient(options)

params = StreamingParameters(sample_rate=16000)
client.connect(params)

expected_headers = {
"sample_rate": params.sample_rate,
}

assert actual_url == f"wss://api.example.com/v3/ws?{urlencode(expected_headers)}"
assert actual_additional_headers["Authorization"] == "test"
assert actual_additional_headers["AssemblyAI-Version"] == "2025-05-12"
assert "AssemblyAI/1.0" in actual_additional_headers["User-Agent"]

assert actual_open_timeout == 15


def test_client_connect_all_parameters(mocker: MockFixture):
actual_url = None
actual_additional_headers = None
Expand All @@ -80,7 +118,6 @@ def mocked_websocket_connect(

params = StreamingParameters(
sample_rate=16000,
word_finalization_max_wait_time=5000,
end_of_turn_confidence_threshold=0.5,
min_end_of_turn_silence_when_confident=2000,
max_turn_silence=3000,
Expand All @@ -89,7 +126,6 @@ def mocked_websocket_connect(
client.connect(params)

expected_headers = {
"word_finalization_max_wait_time": params.word_finalization_max_wait_time,
"end_of_turn_confidence_threshold": params.end_of_turn_confidence_threshold,
"min_end_of_turn_silence_when_confident": params.min_end_of_turn_silence_when_confident,
"max_turn_silence": params.max_turn_silence,
Expand Down
Loading