Skip to content

Commit b9d9e38

Browse files
s0h3ylAssemblyAI
authored andcommitted
feat: add summarization functionality (#12)
Co-authored-by: AssemblyAI <[email protected]>
1 parent 98cb49b commit b9d9e38

File tree

6 files changed

+199
-5
lines changed

6 files changed

+199
-5
lines changed

README.md

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
<img src="https://github.com/AssemblyAI/assemblyai-python-sdk/blob/master/assemblyai.png?raw=true" width="500"/>
22

33
---
4+
45
[![CI Passing](https://github.com/AssemblyAI/assemblyai-python-sdk/actions/workflows/test.yml/badge.svg)](https://github.com/AssemblyAI/assemblyai-python-sdk/actions/workflows/test.yml)
56
[![GitHub License](https://img.shields.io/github/license/AssemblyAI/assemblyai-python-sdk)](https://github.com/AssemblyAI/assemblyai-python-sdk/blob/master/LICENSE)
67
[![PyPI version](https://badge.fury.io/py/assemblyai.svg)](https://badge.fury.io/py/assemblyai)
@@ -26,14 +27,12 @@ With a single API call, get access to AI models built on the latest AI breakthro
2627
- [Playgrounds](#playgrounds)
2728
- [Advanced](#advanced-todo)
2829

29-
3030
# Documentation
3131

3232
Visit our [AssemblyAI API Documentation](https://www.assemblyai.com/docs) to get an overview of our models!
3333

3434
# Quick Start
3535

36-
3736
## Installation
3837

3938
```bash
@@ -66,6 +65,7 @@ transcript = transcriber.transcribe("./my-local-audio-file.wav")
6665

6766
print(transcript.text)
6867
```
68+
6969
</details>
7070

7171
<details>
@@ -79,6 +79,7 @@ transcript = transcriber.transcribe("https://example.org/audio.mp3")
7979

8080
print(transcript.text)
8181
```
82+
8283
</details>
8384

8485
<details>
@@ -96,6 +97,7 @@ print(transcript.export_subtitles_srt())
9697
# in VTT format
9798
print(transcript.export_subtitles_vtt())
9899
```
100+
99101
</details>
100102

101103
<details>
@@ -115,6 +117,7 @@ paragraphs = transcript.get_paragraphs()
115117
for paragraph in paragraphs:
116118
print(paragraph.text)
117119
```
120+
118121
</details>
119122

120123
<details>
@@ -131,6 +134,7 @@ matches = transcript.word_search(["price", "product"])
131134
for match in matches:
132135
print(f"Found '{match.text}' {match.count} times in the transcript")
133136
```
137+
134138
</details>
135139

136140
<details>
@@ -152,9 +156,40 @@ transcript = transcriber.transcribe("https://example.org/audio.mp3", config)
152156

153157
print(transcript.text)
154158
```
159+
160+
</details>
161+
162+
<details>
163+
<summary>Summarize the content of a transcript</summary>
164+
165+
```python
166+
import assemblyai as aai
167+
168+
transcriber = aai.Transcriber()
169+
transcript = transcriber.transcribe(
170+
"https://example.org/audio.mp3",
171+
config=aai.TranscriptionConfig(summarize=True)
172+
)
173+
174+
print(transcript.summary)
175+
```
176+
177+
By default, the summarization model will be `informative` and the summarization type will be `bullets`. [Read more about summarization models and types here](https://www.assemblyai.com/docs/Models/summarization#types-and-models).
178+
179+
To change the model and/or type, pass additional parameters to the `TranscriptionConfig`:
180+
181+
```python
182+
config=aai.TranscriptionConfig(
183+
summarize=True,
184+
summary_model=aai.SummarizationModel.catchy,
185+
summary_type=aai.Summarizationtype.headline
186+
)
187+
```
188+
155189
</details>
156190

157191
---
192+
158193
### **LeMUR Examples**
159194

160195
<details>
@@ -175,6 +210,7 @@ summary = transcript_group.lemur.summarize(context="Customers asking for cars",
175210

176211
print(summary)
177212
```
213+
178214
</details>
179215

180216
<details>
@@ -195,6 +231,7 @@ feedback = transcript_group.lemur.ask_coach(context="Who was the best interviewe
195231

196232
print(feedback)
197233
```
234+
198235
</details>
199236

200237
<details>
@@ -218,6 +255,7 @@ for result in result:
218255
print(f"Question: {result.question}")
219256
print(f"Answer: {result.answer}")
220257
```
258+
221259
</details>
222260

223261
---
@@ -247,8 +285,8 @@ config.set_pii_redact(
247285
transcriber = aai.Transcriber()
248286
transcript = transcriber.transcribe("https://example.org/audio.mp3", config)
249287
```
250-
</details>
251288

289+
</details>
252290

253291
---
254292

assemblyai/transcriber.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,12 @@ def text(self) -> Optional[str]:
204204

205205
return self._impl.transcript.text
206206

207+
@property
208+
def summary(self) -> Optional[str]:
209+
"The summarization of the transcript"
210+
211+
return self._impl.transcript.summary
212+
207213
@property
208214
def status(self) -> types.TranscriptStatus:
209215
"The current status of the transcript"

assemblyai/types.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1004,6 +1004,16 @@ def set_summarize(
10041004

10051005
return self
10061006

1007+
# Validate that required parameters are also set
1008+
if self._raw_transcription_config.punctuate == False:
1009+
raise ValueError(
1010+
"If `summarization` is enabled, then `punctuate` must not be disabled"
1011+
)
1012+
if self._raw_transcription_config.format_text == False:
1013+
raise ValueError(
1014+
"If `summarization` is enabled, then `format_text` must not be disabled"
1015+
)
1016+
10071017
self._raw_transcription_config.summarization = True
10081018
self._raw_transcription_config.summary_model = model
10091019
self._raw_transcription_config.summary_type = type
@@ -1379,6 +1389,9 @@ class TranscriptResponse(BaseTranscript):
13791389
webhook_auth: Optional[bool]
13801390
"Whether the webhook was sent with an HTTP authentication header"
13811391

1392+
summary: Optional[str]
1393+
"The summarization of the transcript"
1394+
13821395
# auto_highlights_result: Optional[AutohighlightResponse] = None
13831396
# "The list of results when enabling Automatic Transcript Highlights"
13841397

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
setup(
99
name="assemblyai",
10-
version="0.5.1",
10+
version="0.6.0",
1111
description="AssemblyAI Python SDK",
1212
author="AssemblyAI",
1313
author_email="[email protected]",

tests/unit/factories.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ class Meta:
200200
audio_duration = factory.Faker("pyint")
201201

202202

203-
def generate_dict_factory(f: factory.Factory) -> Callable[[None], Dict[str, Any]]:
203+
def generate_dict_factory(f: factory.Factory) -> Callable[[], Dict[str, Any]]:
204204
"""
205205
Creates a dict factory from the given *Factory class.
206206

tests/unit/test_summarization.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
import json
2+
from typing import Any, Dict
3+
4+
import httpx
5+
import pytest
6+
from pytest_httpx import HTTPXMock
7+
8+
import assemblyai as aai
9+
from tests.unit import factories
10+
11+
aai.settings.api_key = "test"
12+
13+
14+
def __submit_request(httpx_mock: HTTPXMock, **params) -> Dict[str, Any]:
15+
"""
16+
Helper function to abstract calling transcriber with given parameters,
17+
and perform some common assertions.
18+
19+
Returns the body (dictionary) of the initial submission request.
20+
"""
21+
summary = "example summary"
22+
23+
mock_transcript_response = factories.generate_dict_factory(
24+
factories.TranscriptCompletedResponseFactory
25+
)()
26+
27+
# Mock initial submission response
28+
httpx_mock.add_response(
29+
url=f"{aai.settings.base_url}/transcript",
30+
status_code=httpx.codes.OK,
31+
method="POST",
32+
json=mock_transcript_response,
33+
)
34+
35+
# Mock polling-for-completeness response, with mock summary result
36+
httpx_mock.add_response(
37+
url=f"{aai.settings.base_url}/transcript/{mock_transcript_response['id']}",
38+
status_code=httpx.codes.OK,
39+
method="GET",
40+
json={**mock_transcript_response, "summary": summary},
41+
)
42+
43+
# == Make API request via SDK ==
44+
transcript = aai.Transcriber().transcribe(
45+
data="https://example.org/audio.wav",
46+
config=aai.TranscriptionConfig(
47+
**params,
48+
),
49+
)
50+
51+
# Check that submission and polling requests were made
52+
assert len(httpx_mock.get_requests()) == 2
53+
54+
# Check that summary field from response was traced back through SDK classes
55+
assert transcript.summary == summary
56+
57+
# Extract and return body of initial submission request
58+
request = httpx_mock.get_requests()[0]
59+
return json.loads(request.content.decode())
60+
61+
62+
@pytest.mark.parametrize("required_field", ["punctuate", "format_text"])
63+
def test_summarization_fails_without_required_field(
64+
httpx_mock: HTTPXMock, required_field: str
65+
):
66+
"""
67+
Tests whether the SDK raises an error before making a request
68+
if `summarization` is enabled and the given required field is disabled
69+
"""
70+
with pytest.raises(ValueError) as error:
71+
__submit_request(httpx_mock, summarization=True, **{required_field: False})
72+
73+
# Check that the error message informs the user of the invalid parameter
74+
assert required_field in str(error)
75+
76+
# Check that the error was raised before any requests were made
77+
assert len(httpx_mock.get_requests()) == 0
78+
79+
# Inform httpx_mock that it's okay we didn't make any requests
80+
httpx_mock.reset(False)
81+
82+
83+
def test_summarization_disabled_by_default(httpx_mock: HTTPXMock):
84+
"""
85+
Tests that excluding `summarization` from the `TranscriptionConfig` will
86+
result in the default behavior of it being excluded from the request body
87+
"""
88+
request_body = __submit_request(httpx_mock)
89+
assert request_body.get("summarization") is None
90+
91+
92+
def test_default_summarization_params(httpx_mock: HTTPXMock):
93+
"""
94+
Tests that including `summarization=True` in the `TranscriptionConfig`
95+
will result in `summarization=True` in the request body.
96+
"""
97+
request_body = __submit_request(httpx_mock, summarization=True)
98+
assert request_body.get("summarization") == True
99+
100+
101+
def test_summarization_with_params(httpx_mock: HTTPXMock):
102+
"""
103+
Tests that including additional summarization parameters along with
104+
`summarization=True` in the `TranscriptionConfig` will result in all
105+
parameters being included in the request as well.
106+
"""
107+
108+
summary_model = aai.SummarizationModel.conversational
109+
summary_type = aai.SummarizationType.bullets
110+
111+
request_body = __submit_request(
112+
httpx_mock,
113+
summarization=True,
114+
summary_model=summary_model,
115+
summary_type=summary_type,
116+
)
117+
118+
assert request_body.get("summarization") == True
119+
assert request_body.get("summary_model") == summary_model
120+
assert request_body.get("summary_type") == summary_type
121+
122+
123+
def test_summarization_params_excluded_when_disabled(httpx_mock: HTTPXMock):
124+
"""
125+
Tests that additional summarization parameters are excluded from the submission
126+
request body if `summarization` itself is not enabled.
127+
"""
128+
request_body = __submit_request(
129+
httpx_mock,
130+
summarization=False,
131+
summary_model=aai.SummarizationModel.conversational,
132+
summary_type=aai.SummarizationType.bullets,
133+
)
134+
135+
assert request_body.get("summarization") is None
136+
assert request_body.get("summary_model") is None
137+
assert request_body.get("summary_type") is None

0 commit comments

Comments
 (0)