Skip to content

Commit b675bf6

Browse files
dmccrystal authored and s0h3yl committed
feat: add auto_highlights functionality
GitOrigin-RevId: 3a68cd7986efea85c30a712dda45b8de245019f7
1 parent df5869b commit b675bf6

File tree

5 files changed

+193
-16
lines changed

5 files changed

+193
-16
lines changed

README.md

Lines changed: 23 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -412,6 +412,29 @@ for entity in transcript.entities:
412412

413413
[Read more about entity detection here.](https://www.assemblyai.com/docs/Models/entity_detection)
414414

415+
</details>
416+
<details>
417+
<summary>Identify Important Words and Phrases in a Transcript</summary>
418+
419+
```python
420+
import assemblyai as aai
421+
422+
transcriber = aai.Transcriber()
423+
transcript = transcriber.transcribe(
424+
"https://example.org/audio.mp3",
425+
config=aai.TranscriptionConfig(auto_highlights=True)
426+
)
427+
428+
for result in transcript.auto_highlights_result.results:
429+
print(result.text) # the important phrase
430+
print(result.rank) # relevancy of the phrase
431+
print(result.count) # number of instances of the phrase
432+
for timestamp in result.timestamps:
433+
print(f"Timestamp: {timestamp.start} - {timestamp.end}")
434+
```
435+
436+
[Read more about auto highlights here.](https://www.assemblyai.com/docs/Models/key_phrases)
437+
415438
</details>
416439

417440
---

assemblyai/transcriber.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -226,6 +226,10 @@ def sentiment_analysis_results(self) -> Optional[List[types.Sentiment]]:
226226
def entities(self) -> Optional[List[types.Entity]]:
227227
return self._impl.transcript.entities
228228

229+
@property
def auto_highlights_result(self) -> Optional[types.AutohighlightResponse]:
    "The `auto_highlights_result` of the underlying transcript, if any"
    underlying_transcript = self._impl.transcript
    return underlying_transcript.auto_highlights_result
232+
229233
@property
230234
def status(self) -> types.TranscriptStatus:
231235
"The current status of the transcript"

assemblyai/types.py

Lines changed: 16 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -370,8 +370,8 @@ class RawTranscriptionConfig(BaseModel):
370370
summary_type: Optional[SummarizationType]
371371
"The summarization type to use in case `summarization` is enabled"
372372

373-
# auto_highlights: bool = False
374-
# "Detect important phrases and words in your transcription text."
373+
auto_highlights: Optional[bool]
374+
"Detect important phrases and words in your transcription text."
375375

376376
language_detection: Optional[bool]
377377
"""
@@ -424,7 +424,7 @@ def __init__(
424424
summarization: Optional[bool] = None,
425425
summary_model: Optional[SummarizationModel] = None,
426426
summary_type: Optional[SummarizationType] = None,
427-
# auto_highlights: bool = False,
427+
auto_highlights: Optional[bool] = None,
428428
language_detection: Optional[bool] = None,
429429
raw_transcription_config: Optional[RawTranscriptionConfig] = None,
430430
) -> None:
@@ -502,7 +502,7 @@ def __init__(
502502
summary_model,
503503
summary_type,
504504
)
505-
# self.auto_highlights = auto_highlights
505+
self.auto_highlights = auto_highlights
506506
self.language_detection = language_detection
507507

508508
@property
@@ -793,17 +793,17 @@ def summary_type(self) -> Optional[SummarizationType]:
793793

794794
return self._raw_transcription_config.summary_type
795795

796-
# @property
797-
# def auto_highlights(self) -> bool:
798-
# "Returns whether the Auto Highlights feature is enabled or not."
796+
@property
797+
def auto_highlights(self) -> Optional[bool]:
798+
"Returns whether the Auto Highlights feature is enabled or not."
799799

800-
# return self._raw_transcription_config.auto_highlights
800+
return self._raw_transcription_config.auto_highlights
801801

802-
# @auto_highlights.setter
803-
# def auto_highlights(self, enable: bool) -> None:
804-
# "Detect important phrases and words in your transcription text."
802+
@auto_highlights.setter
803+
def auto_highlights(self, enable: Optional[bool]) -> None:
804+
"Detect important phrases and words in your transcription text."
805805

806-
# self._raw_transcription_config.auto_highlights = enable
806+
self._raw_transcription_config.auto_highlights = enable
807807

808808
@property
809809
def language_detection(self) -> Optional[bool]:
@@ -1380,8 +1380,8 @@ class BaseTranscript(BaseModel):
13801380
summary_type: Optional[SummarizationType]
13811381
"The summarization type to use in case `summarization` is enabled"
13821382

1383-
# auto_highlights: bool = False
1384-
# "Detect important phrases and words in your transcription text."
1383+
auto_highlights: Optional[bool]
1384+
"Detect important phrases and words in your transcription text."
13851385

13861386
language_detection: Optional[bool]
13871387
"""
@@ -1442,8 +1442,8 @@ class TranscriptResponse(BaseTranscript):
14421442
summary: Optional[str]
14431443
"The summarization of the transcript"
14441444

1445-
# auto_highlights_result: Optional[AutohighlightResponse] = None
1446-
# "The list of results when enabling Automatic Transcript Highlights"
1445+
auto_highlights_result: Optional[AutohighlightResponse]
1446+
"The list of results when enabling Automatic Transcript Highlights"
14471447

14481448
content_safety_labels: Optional[ContentSafetyResponse]
14491449
"The list of results when Content Safety is enabled"

tests/unit/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ py_test(
2323
name = "unit_test",
2424
srcs = [
2525
"test_auto_chapters.py",
26+
"test_auto_highlights.py",
2627
"test_client.py",
2728
"test_config.py",
2829
"test_content_safety.py",

tests/unit/test_auto_highlights.py

Lines changed: 149 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,149 @@
1+
import json
2+
from typing import Any, Dict, Tuple
3+
4+
import factory
5+
import httpx
6+
from pytest_httpx import HTTPXMock
7+
8+
import assemblyai as aai
9+
from tests.unit import factories
10+
11+
aai.settings.api_key = "test"
12+
13+
14+
class AutohighlightResultFactory(factory.Factory):
    """Factory building `aai.types.AutohighlightResult` objects from Faker data."""

    class Meta:
        model = aai.types.AutohighlightResult

    # Scalar fields are filled by Faker providers.
    count = factory.Faker("pyint")
    rank = factory.Faker("pyfloat")
    text = factory.Faker("sentence")

    # A one-element list whose entry is produced by `TimestampFactory`.
    timestamps = factory.List(
        [factory.SubFactory(factories.TimestampFactory)],
    )
22+
23+
24+
class AutohighlightResponseFactory(factory.Factory):
    """
    Factory building `aai.types.AutohighlightResponse` objects with a
    `success` status and a single generated result.
    """

    class Meta:
        model = aai.types.AutohighlightResponse

    status = aai.types.StatusResult.success

    # A one-element list whose entry is produced by `AutohighlightResultFactory`.
    results = factory.List(
        [factory.SubFactory(AutohighlightResultFactory)],
    )
30+
31+
32+
class AutohighlightTranscriptResponseFactory(
    factories.TranscriptCompletedResponseFactory
):
    """
    Extends `TranscriptCompletedResponseFactory` with an
    `auto_highlights_result` built by `AutohighlightResponseFactory`.
    """

    auto_highlights_result = factory.SubFactory(AutohighlightResponseFactory)
36+
37+
38+
def __submit_mock_request(
    httpx_mock: HTTPXMock,
    mock_response: Dict[str, Any],
    config: aai.TranscriptionConfig,
) -> Tuple[Dict[str, Any], aai.Transcript]:
    """
    Run one mocked transcription round-trip with the given `TranscriptionConfig`.

    Registers a mocked POST (submission) response and a mocked GET (polling)
    response, transcribes a placeholder audio URL through the SDK, and asserts
    that exactly two HTTP requests were recorded.

    Args:
        httpx_mock: the `pytest_httpx` mock the responses are registered on.
        mock_response: JSON payload served for the polling request; its `id`
            (or `"mock_id"` when absent) is reused for the submission response.
        config: the transcription config passed to `Transcriber.transcribe`.

    Returns:
        A tuple of (decoded JSON body of the submission request, the
        `Transcript` returned by the SDK).
    """
    transcript_id = mock_response.get("id", "mock_id")
    transcript_endpoint = f"{aai.settings.base_url}/transcript"

    # Submission response: a "processing" transcript that carries the ID of
    # the main mock response, so the follow-up poll hits the URL mocked below.
    processing_payload = factories.generate_dict_factory(
        factories.TranscriptProcessingResponseFactory
    )()
    submission_payload = {**processing_payload, "id": transcript_id}
    httpx_mock.add_response(
        url=transcript_endpoint,
        status_code=httpx.codes.OK,
        method="POST",
        json=submission_payload,
    )

    # Polling-for-completion response: serves the caller's mock payload.
    httpx_mock.add_response(
        url=f"{transcript_endpoint}/{transcript_id}",
        status_code=httpx.codes.OK,
        method="GET",
        json=mock_response,
    )

    # == Make API request via SDK ==
    transcript = aai.Transcriber().transcribe(
        data="https://example.org/audio.wav",
        config=config,
    )

    # Both the submission and the polling request must have been made.
    recorded_requests = httpx_mock.get_requests()
    assert len(recorded_requests) == 2

    # The first recorded request is the submission; decode its JSON body.
    submission_request = recorded_requests[0]
    submitted_body = json.loads(submission_request.content.decode())

    return submitted_body, transcript
87+
88+
89+
def test_auto_highlights_disabled_by_default(httpx_mock: HTTPXMock):
    """
    Tests that a `TranscriptionConfig` created without `auto_highlights` sends
    no `auto_highlights` value in the request body, and that the resulting
    transcript has no `auto_highlights_result`.
    """
    completed_response = factories.generate_dict_factory(
        factories.TranscriptCompletedResponseFactory
    )()
    default_config = aai.TranscriptionConfig()

    request_body, transcript = __submit_mock_request(
        httpx_mock,
        mock_response=completed_response,
        config=default_config,
    )

    assert request_body.get("auto_highlights") is None
    assert transcript.auto_highlights_result is None
103+
104+
105+
def test_auto_highlights_enabled(httpx_mock: HTTPXMock):
    """
    Tests that including `auto_highlights=True` in the `TranscriptionConfig`
    will result in `auto_highlights=True` in the request body, and that the
    response is properly parsed into a `Transcript` object
    """
    mock_response = factories.generate_dict_factory(
        AutohighlightTranscriptResponseFactory
    )()
    request_body, transcript = __submit_mock_request(
        httpx_mock,
        mock_response=mock_response,
        config=aai.TranscriptionConfig(auto_highlights=True),
    )

    # Check that request body was properly defined.
    # Identity check: `== True` would also accept truthy look-alikes such as 1.
    assert request_body.get("auto_highlights") is True

    expected_highlights = mock_response["auto_highlights_result"]

    # Check that transcript was properly parsed from JSON response
    assert transcript.error is None
    assert transcript.auto_highlights_result is not None
    assert transcript.auto_highlights_result.status == expected_highlights["status"]

    assert transcript.auto_highlights_result.results is not None
    assert len(transcript.auto_highlights_result.results) > 0
    assert len(transcript.auto_highlights_result.results) == len(
        expected_highlights["results"]
    )

    for response_result, transcript_result in zip(
        expected_highlights["results"],
        transcript.auto_highlights_result.results,
    ):
        assert transcript_result.count == response_result["count"]
        assert transcript_result.rank == response_result["rank"]
        assert transcript_result.text == response_result["text"]

        # zip() stops at the shorter sequence, so compare lengths first;
        # otherwise dropped timestamps would go unnoticed.
        assert len(transcript_result.timestamps) == len(
            response_result["timestamps"]
        )

        for response_timestamp, transcript_timestamp in zip(
            response_result["timestamps"], transcript_result.timestamps
        ):
            assert transcript_timestamp.start == response_timestamp["start"]
            assert transcript_timestamp.end == response_timestamp["end"]

0 commit comments

Comments
 (0)