Skip to content

Commit 1b1f91a

Browse files
dmccrystals0h3yl
andcommitted
feat: add entity_detection functionality
Co-authored-by: Soheyl <[email protected]> GitOrigin-RevId: 45be764b2c352a90e59bc41cff255195ddfa07f5
1 parent 9042400 commit 1b1f91a

File tree

4 files changed

+159
-16
lines changed

4 files changed

+159
-16
lines changed

README.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,27 @@ for sentiment_result in transcript.sentiment_analysis_results:
391391

392392
[Read more about sentiment analysis here.](https://www.assemblyai.com/docs/Models/sentiment_analysis)
393393

394+
</details>
395+
<details>
396+
<summary>Identify Entities in a Transcript</summary>
397+
398+
```python
399+
import assemblyai as aai
400+
401+
transcriber = aai.Transcriber()
402+
transcript = transcriber.transcribe(
403+
"https://example.org/audio.mp3",
404+
config=aai.TranscriptionConfig(entity_detection=True)
405+
)
406+
407+
for entity in transcript.entities:
408+
print(entity.text) # i.e. "Dan Gilbert"
409+
print(entity.type) # i.e. EntityType.person
410+
print(f"Timestamp: {entity.start} - {entity.end}")
411+
```
412+
413+
[Read more about entity detection here.](https://www.assemblyai.com/docs/Models/entity_detection)
414+
394415
</details>
395416

396417
---

assemblyai/transcriber.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,10 @@ def content_safety_labels(self) -> Optional[types.ContentSafetyResponse]:
222222
def sentiment_analysis_results(self) -> Optional[List[types.Sentiment]]:
223223
return self._impl.transcript.sentiment_analysis_results
224224

225+
@property
226+
def entities(self) -> Optional[List[types.Entity]]:
227+
return self._impl.transcript.entities
228+
225229
@property
226230
def status(self) -> types.TranscriptStatus:
227231
"The current status of the transcript"

assemblyai/types.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -360,8 +360,8 @@ class RawTranscriptionConfig(BaseModel):
360360
auto_chapters: Optional[bool]
361361
"Enable Auto Chapters."
362362

363-
# entity_detection: bool = False
364-
# "Enable Entity Detection."
363+
entity_detection: Optional[bool]
364+
"Enable Entity Detection."
365365

366366
summarization: Optional[bool]
367367
"Enable Summarization"
@@ -420,7 +420,7 @@ def __init__(
420420
disfluencies: Optional[bool] = None,
421421
sentiment_analysis: Optional[bool] = None,
422422
auto_chapters: Optional[bool] = None,
423-
# entity_detection: bool = False,
423+
entity_detection: Optional[bool] = None,
424424
summarization: Optional[bool] = None,
425425
summary_model: Optional[SummarizationModel] = None,
426426
summary_type: Optional[SummarizationType] = None,
@@ -496,7 +496,7 @@ def __init__(
496496
self.disfluencies = disfluencies
497497
self.sentiment_analysis = sentiment_analysis
498498
self.auto_chapters = auto_chapters
499-
# self.entity_detection = entity_detection
499+
self.entity_detection = entity_detection
500500
self.set_summarize(
501501
summarization,
502502
summary_model,
@@ -763,17 +763,17 @@ def auto_chapters(self, enable: Optional[bool]) -> None:
763763

764764
self._raw_transcription_config.auto_chapters = enable
765765

766-
# @property
767-
# def entity_detection(self) -> bool:
768-
# "Returns whether Entity Detection feature is enabled or not."
766+
@property
767+
def entity_detection(self) -> bool:
768+
"Returns whether Entity Detection feature is enabled or not."
769769

770-
# return self._raw_transcription_config.entity_detection
770+
return self._raw_transcription_config.entity_detection
771771

772-
# @entity_detection.setter
773-
# def entity_detection(self, enable: bool) -> None:
774-
# "Enable Entity Detection."
772+
@entity_detection.setter
773+
def entity_detection(self, enable: Optional[bool]) -> None:
774+
"Enable Entity Detection."
775775

776-
# self._raw_transcription_config.entity_detection = enable
776+
self._raw_transcription_config.entity_detection = enable
777777

778778
@property
779779
def summarization(self) -> Optional[bool]:
@@ -1370,8 +1370,8 @@ class BaseTranscript(BaseModel):
13701370
auto_chapters: Optional[bool]
13711371
"Enable Auto Chapters."
13721372

1373-
# entity_detection: bool = False
1374-
# "Enable Entity Detection."
1373+
entity_detection: Optional[bool]
1374+
"Enable Entity Detection."
13751375

13761376
summarization: Optional[bool]
13771377
"Enable Summarization"
@@ -1457,8 +1457,8 @@ class TranscriptResponse(BaseTranscript):
14571457
sentiment_analysis_results: Optional[List[Sentiment]]
14581458
"When Sentiment Analysis is enabled, the list of Sentiment Analysis results"
14591459

1460-
# entities: Optional[List[Entity]] = None
1461-
# "When Entity Detection is enabled, the list of detected Entities"
1460+
entities: Optional[List[Entity]]
1461+
"When Entity Detection is enabled, the list of detected Entities"
14621462

14631463
def __init__(self, **data: Any):
14641464
# cleanup the response before creating the object

tests/unit/test_entity_detection.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
import json
2+
from typing import Any, Dict, Tuple
3+
4+
import factory
5+
import httpx
6+
from pytest_httpx import HTTPXMock
7+
8+
import assemblyai as aai
9+
from tests.unit import factories
10+
11+
aai.settings.api_key = "test"
12+
13+
14+
class EntityFactory(factory.Factory):
15+
class Meta:
16+
model = aai.types.Entity
17+
18+
entity_type = factory.Faker("enum", enum_cls=aai.types.EntityType)
19+
text = factory.Faker("sentence")
20+
start = factory.Faker("pyint")
21+
end = factory.Faker("pyint")
22+
23+
24+
class EntityDetectionResponseFactory(factories.TranscriptCompletedResponseFactory):
25+
entities = factory.List([factory.SubFactory(EntityFactory)])
26+
27+
28+
def __submit_mock_request(
29+
httpx_mock: HTTPXMock,
30+
mock_response: Dict[str, Any],
31+
config: aai.TranscriptionConfig,
32+
) -> Tuple[Dict[str, Any], aai.Transcript]:
33+
"""
34+
Helper function to abstract mock transcriber calls with given `TranscriptionConfig`,
35+
and perform some common assertions.
36+
"""
37+
38+
mock_transcript_id = mock_response.get("id", "mock_id")
39+
40+
# Mock initial submission response (transcript is processing)
41+
mock_processing_response = factories.generate_dict_factory(
42+
factories.TranscriptProcessingResponseFactory
43+
)()
44+
45+
httpx_mock.add_response(
46+
url=f"{aai.settings.base_url}/transcript",
47+
status_code=httpx.codes.OK,
48+
method="POST",
49+
json={
50+
**mock_processing_response,
51+
"id": mock_transcript_id, # inject ID from main mock response
52+
},
53+
)
54+
55+
# Mock polling-for-completeness response, with completed transcript
56+
httpx_mock.add_response(
57+
url=f"{aai.settings.base_url}/transcript/{mock_transcript_id}",
58+
status_code=httpx.codes.OK,
59+
method="GET",
60+
json=mock_response,
61+
)
62+
63+
# == Make API request via SDK ==
64+
transcript = aai.Transcriber().transcribe(
65+
data="https://example.org/audio.wav",
66+
config=config,
67+
)
68+
69+
# Check that submission and polling requests were made
70+
assert len(httpx_mock.get_requests()) == 2
71+
72+
# Extract body of initial submission request
73+
request = httpx_mock.get_requests()[0]
74+
request_body = json.loads(request.content.decode())
75+
76+
return request_body, transcript
77+
78+
79+
def test_entity_detection_disabled_by_default(httpx_mock: HTTPXMock):
80+
"""
81+
Tests that excluding `entity_detection` from the `TranscriptionConfig` will
82+
result in the default behavior of it being excluded from the request body
83+
"""
84+
request_body, transcript = __submit_mock_request(
85+
httpx_mock,
86+
mock_response=factories.generate_dict_factory(
87+
factories.TranscriptCompletedResponseFactory
88+
)(),
89+
config=aai.TranscriptionConfig(),
90+
)
91+
assert request_body.get("entity_detection") is None
92+
assert transcript.entities is None
93+
94+
95+
def test_entity_detection_enabled(httpx_mock: HTTPXMock):
96+
"""
97+
Tests that including `entity_detection=True` in the `TranscriptionConfig`
98+
will result in `entity_detection=True` in the request body, and that the
99+
response is properly parsed into a `Transcript` object
100+
"""
101+
mock_response = factories.generate_dict_factory(EntityDetectionResponseFactory)()
102+
request_body, transcript = __submit_mock_request(
103+
httpx_mock,
104+
mock_response=mock_response,
105+
config=aai.TranscriptionConfig(entity_detection=True),
106+
)
107+
108+
# Check that request body was properly defined
109+
assert request_body.get("entity_detection") == True
110+
111+
# Check that transcript was properly parsed from JSON response
112+
assert transcript.error is None
113+
assert transcript.entities is not None
114+
assert len(transcript.entities) > 0
115+
assert len(transcript.entities) == len(mock_response["entities"])
116+
117+
for entity in transcript.entities:
118+
assert len(entity.text.strip()) > 0

0 commit comments

Comments
 (0)