Skip to content

Commit 6bf1c58

Browse files
dmccrystals0h3yl
authored and committed
feat: add iab_categories (topic detection) functionality
GitOrigin-RevId: e03e3eab65e8a4a1f72cacca8725aeb0af0cda0c
1 parent b675bf6 commit 6bf1c58

File tree

5 files changed

+226
-21
lines changed

5 files changed

+226
-21
lines changed

README.md

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,34 @@ for entity in transcript.entities:
412412

413413
[Read more about entity detection here.](https://www.assemblyai.com/docs/Models/entity_detection)
414414

415+
</details>
416+
<details>
417+
<summary>Detect Topics in a Transcript (IAB Classification)</summary>
418+
419+
```python
420+
import assemblyai as aai
421+
422+
transcriber = aai.Transcriber()
423+
transcript = transcriber.transcribe(
424+
"https://example.org/audio.mp3",
425+
config=aai.TranscriptionConfig(iab_categories=True)
426+
)
427+
428+
# Get the parts of the transcript that were tagged with topics
429+
for result in transcript.iab_categories.results:
430+
print(result.text)
431+
print(f"Timestamp: {result.timestamp.start} - {result.timestamp.end}")
432+
for label in result.labels:
433+
print(label.label) # topic
434+
print(label.relevance) # how relevant the label is for the portion of text
435+
436+
# Get a summary of all topics in the transcript
437+
for label, relevance in transcript.iab_categories.summary.items():
438+
print(f"Audio is {relevance * 100}% relevant to {label}")
439+
```
440+
441+
[Read more about IAB classification here.](https://www.assemblyai.com/docs/Models/iab_classification)
442+
415443
</details>
416444
<details>
417445
<summary>Identify Important Words and Phrases in a Transcript</summary>
@@ -426,11 +454,12 @@ transcript = transcriber.transcribe(
426454
)
427455

428456
for result in transcript.auto_highlights_result.results:
429-
print(result.text) # the important phrase
430-
print(result.rank) # relevancy of the phrase
431-
print(result.count) # number of instances of the phrase
457+
print(result.text) # the important phrase
458+
print(result.rank) # relevancy of the phrase
459+
print(result.count) # number of instances of the phrase
432460
for timestamp in result.timestamps:
433461
print(f"Timestamp: {timestamp.start} - {timestamp.end}")
462+
434463
```
435464

436465
[Read more about auto highlights here.](https://www.assemblyai.com/docs/Models/key_phrases)

assemblyai/transcriber.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,10 @@ def sentiment_analysis_results(self) -> Optional[List[types.Sentiment]]:
226226
def entities(self) -> Optional[List[types.Entity]]:
227227
return self._impl.transcript.entities
228228

229+
@property
def iab_categories(self) -> Optional[types.IABResponse]:
    """Topic Detection (IAB classification) result of this transcript.

    Returns the `iab_categories_result` value held by the wrapped
    transcript object; per the annotation this may be `None`.
    """
    transcript = self._impl.transcript
    return transcript.iab_categories_result
232+
229233
@property
230234
def auto_highlights_result(self) -> Optional[types.AutohighlightResponse]:
231235
return self._impl.transcript.auto_highlights_result

assemblyai/types.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -345,8 +345,8 @@ class RawTranscriptionConfig(BaseModel):
345345
content_safety_confidence: Optional[int]
346346
"The minimum confidence level for a content safety label to be produced."
347347

348-
# iab_categories: bool = False
349-
# "Enable Topic Detection."
348+
iab_categories: Optional[bool]
349+
"Enable Topic Detection."
350350

351351
custom_spelling: Optional[List[Dict[str, List[str]]]]
352352
"Customize how words are spelled and formatted using to and from values"
@@ -415,7 +415,7 @@ def __init__(
415415
speakers_expected: Optional[int] = None,
416416
content_safety: Optional[bool] = None,
417417
content_safety_confidence: Optional[int] = None,
418-
# iab_categories: bool = False,
418+
iab_categories: Optional[bool] = None,
419419
custom_spelling: Optional[Dict[str, Union[str, Sequence[str]]]] = None,
420420
disfluencies: Optional[bool] = None,
421421
sentiment_analysis: Optional[bool] = None,
@@ -491,7 +491,7 @@ def __init__(
491491
)
492492
self.set_speaker_diarization(speaker_labels, speakers_expected)
493493
self.set_content_safety(content_safety, content_safety_confidence)
494-
# self.iab_categories = iab_categories
494+
self.iab_categories = iab_categories
495495
self.set_custom_spelling(custom_spelling, override=True)
496496
self.disfluencies = disfluencies
497497
self.sentiment_analysis = sentiment_analysis
@@ -694,17 +694,17 @@ def set_content_safety(
694694

695695
return self
696696

697-
# @property
698-
# def iab_categories(self) -> bool:
699-
# "Returns the status of the Topic Detection feature."
697+
@property
def iab_categories(self) -> Optional[bool]:
    """Return the Topic Detection flag stored on the raw transcription config.

    Per the annotation the value may be `True`, `False`, or `None`.
    """
    raw_config = self._raw_transcription_config
    return raw_config.iab_categories

@iab_categories.setter
def iab_categories(self, enable: Optional[bool]) -> None:
    """Store the Topic Detection flag on the raw transcription config.

    Args:
        enable: Value written unchanged to `iab_categories`.
    """
    raw_config = self._raw_transcription_config
    raw_config.iab_categories = enable
708708

709709
@property
710710
def custom_spelling(self) -> Optional[Dict[str, List[str]]]:
@@ -1355,8 +1355,8 @@ class BaseTranscript(BaseModel):
13551355
content_safety_confidence: Optional[int]
13561356
"The minimum confidence level for a content safety label to be produced."
13571357

1358-
# iab_categories: bool = False
1359-
# "Enable Topic Detection."
1358+
iab_categories: Optional[bool]
1359+
"Enable Topic Detection."
13601360

13611361
custom_spelling: Optional[List[Dict[str, Union[str, List[str]]]]]
13621362
"Customize how words are spelled and formatted using to and from values"
@@ -1448,8 +1448,8 @@ class TranscriptResponse(BaseTranscript):
14481448
content_safety_labels: Optional[ContentSafetyResponse]
14491449
"The list of results when Content Safety is enabled"
14501450

1451-
# iab_categories_result: Optional[IABResponse] = None
1452-
# "The list of results when Topic Detection is enabled"
1451+
iab_categories_result: Optional[IABResponse]
1452+
"The list of results when Topic Detection is enabled"
14531453

14541454
chapters: Optional[List[Chapter]]
14551455
"When Auto Chapters is enabled, the list of Auto Chapters results"
@@ -1462,8 +1462,11 @@ class TranscriptResponse(BaseTranscript):
14621462

14631463
def __init__(self, **data: Any):
14641464
# cleanup the response before creating the object
1465-
# if data.get("iab_categories_result") == {}:
1466-
# data["iab_categories_result"] = None
1465+
if data.get("iab_categories_result") == {} or (
1466+
not data.get("iab_categories")
1467+
and data.get("iab_categories_result", {}).get("status") == "unavailable"
1468+
):
1469+
data["iab_categories_result"] = None
14671470

14681471
if data.get("content_safety_labels") == {} or (
14691472
not data.get("content_safety")

tests/unit/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ py_test(
2929
"test_content_safety.py",
3030
"test_domains.py",
3131
"test_entity_detection.py",
32+
"test_iab_categories.py",
3233
"test_lemur.py",
3334
"test_sentiment_analysis.py",
3435
"test_summarization.py",

tests/unit/test_iab_categories.py

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
import json
2+
from typing import Any, Dict, Tuple
3+
4+
import factory
5+
import httpx
6+
from pytest_httpx import HTTPXMock
7+
8+
import assemblyai as aai
9+
from tests.unit import factories
10+
11+
aai.settings.api_key = "test"
12+
13+
14+
class IABLabelResultFactory(factory.Factory):
    """Factory for `aai.types.IABLabelResult` test objects."""

    class Meta:
        model = aai.types.IABLabelResult

    # Faker `pyfloat` constrained to the 0..1 range
    relevance = factory.Faker("pyfloat", min_value=0, max_value=1)
    # Faker `word` used as the label text
    label = factory.Faker("word")
20+
21+
22+
class IABResultFactory(factory.Factory):
    """Factory for `aai.types.IABResult` test objects."""

    class Meta:
        model = aai.types.IABResult

    text = factory.Faker("sentence")
    # A list holding a single generated label result
    labels = factory.List(
        [
            factory.SubFactory(IABLabelResultFactory),
        ]
    )
    timestamp = factory.SubFactory(factories.TimestampFactory)
29+
30+
31+
class IABResponseFactory(factory.Factory):
    """Factory for `aai.types.IABResponse` test objects with a success status."""

    class Meta:
        model = aai.types.IABResponse

    # Status is fixed to the "success" enum value
    status = aai.types.StatusResult.success.value
    # A list holding a single generated result
    results = factory.List(
        [
            factory.SubFactory(IABResultFactory),
        ]
    )
    # Summary with one fixed category key mapped to a generated 0..1 float
    summary = factory.Dict(
        {
            "Automotive>AutoType>ConceptCars": factory.Faker(
                "pyfloat",
                min_value=0,
                max_value=1,
            ),
        }
    )
44+
45+
46+
class IABCategoriesResponseFactory(factories.TranscriptCompletedResponseFactory):
    """Completed-transcript response factory that also fills `iab_categories_result`."""

    iab_categories_result = factory.SubFactory(IABResponseFactory)
48+
49+
50+
def __submit_mock_request(
    httpx_mock: HTTPXMock,
    mock_response: Dict[str, Any],
    config: aai.TranscriptionConfig,
) -> Tuple[Dict[str, Any], aai.Transcript]:
    """
    Run a mocked transcription with the given `TranscriptionConfig`.

    Registers a "processing" response for the submission POST and
    `mock_response` for the polling GET, calls `aai.Transcriber().transcribe`,
    and asserts that exactly two HTTP requests were made.

    Returns:
        A tuple of the decoded JSON body of the submission request and the
        transcript returned by the SDK.
    """
    base_url = aai.settings.base_url
    transcript_id = mock_response.get("id", "mock_id")

    # Submission response: a "processing" transcript carrying the ID of the
    # main mock response, so the follow-up GET targets the mocked URL below.
    processing_payload = factories.generate_dict_factory(
        factories.TranscriptProcessingResponseFactory
    )()
    processing_payload = {**processing_payload, "id": transcript_id}
    httpx_mock.add_response(
        url=f"{base_url}/transcript",
        status_code=httpx.codes.OK,
        method="POST",
        json=processing_payload,
    )

    # Polling response: the completed transcript
    httpx_mock.add_response(
        url=f"{base_url}/transcript/{transcript_id}",
        status_code=httpx.codes.OK,
        method="GET",
        json=mock_response,
    )

    # Make the API request via the SDK
    transcript = aai.Transcriber().transcribe(
        data="https://example.org/audio.wav",
        config=config,
    )

    # Both the submission and the polling request must have been made
    sent_requests = httpx_mock.get_requests()
    assert len(sent_requests) == 2

    # The first recorded request is the submission; return its JSON body
    submission = sent_requests[0]
    return json.loads(submission.content.decode()), transcript
99+
100+
101+
def test_iab_categories_disabled_by_default(httpx_mock: HTTPXMock):
    """
    With no `iab_categories` in the `TranscriptionConfig`, the request body
    carries no value for it and the transcript exposes no IAB result.
    """
    completed_response = factories.generate_dict_factory(
        factories.TranscriptCompletedResponseFactory
    )()

    request_body, transcript = __submit_mock_request(
        httpx_mock,
        mock_response=completed_response,
        config=aai.TranscriptionConfig(),
    )

    assert request_body.get("iab_categories") is None
    assert transcript.iab_categories is None
116+
117+
118+
def test_iab_categories_enabled(httpx_mock: HTTPXMock):
    """
    With `iab_categories=True` in the `TranscriptionConfig`, the flag is sent
    in the request body and the mocked IAB result is parsed onto the
    `Transcript` object (status, per-result text and labels, summary).
    """
    mock_response = factories.generate_dict_factory(IABCategoriesResponseFactory)()

    request_body, transcript = __submit_mock_request(
        httpx_mock,
        mock_response=mock_response,
        config=aai.TranscriptionConfig(iab_categories=True),
    )

    assert request_body.get("iab_categories") is True
    assert transcript.error is None

    # Raw IAB payload the parsed transcript is compared against
    expected_iab = mock_response.get("iab_categories_result", {})

    assert transcript.iab_categories is not None
    assert transcript.iab_categories.status == expected_iab.get("status")

    # Validate results
    expected_results = expected_iab.get("results", [])
    parsed_results = transcript.iab_categories.results

    assert parsed_results is not None
    assert len(parsed_results) == len(expected_results)
    assert len(parsed_results) > 0

    for expected_result, parsed_result in zip(expected_results, parsed_results):
        assert parsed_result.text == expected_result.get("text")
        assert len(parsed_result.text) > 0

        expected_labels = expected_result.get("labels", [])
        assert len(parsed_result.labels) > 0
        assert len(parsed_result.labels) == len(expected_labels)
        for expected_label, parsed_label in zip(expected_labels, parsed_result.labels):
            assert parsed_label.relevance == expected_label.get("relevance")
            assert parsed_label.label == expected_label.get("label")

    # Validate summary
    expected_summary = expected_iab.get("summary", {})
    parsed_summary = transcript.iab_categories.summary

    assert parsed_summary is not None
    assert len(parsed_summary) > 0
    assert parsed_summary == expected_summary

0 commit comments

Comments
 (0)