Commit 0453022

dmccrystal and s0h3yl committed
feat: add content_safety functionality
Co-authored-by: Soheyl <[email protected]>
GitOrigin-RevId: ade523b5a3e103d0f57fd1bdc424d80b429214d1
1 parent 57d2335 commit 0453022

7 files changed: +422 -27 lines changed

README.md

Lines changed: 48 additions & 0 deletions
@@ -308,6 +308,54 @@ config=aai.TranscriptionConfig(
 )
 ```
 
+</details>
+<details>
+  <summary>Detect Sensitive Content in a Transcript</summary>
+
+```python
+import assemblyai as aai
+
+transcriber = aai.Transcriber()
+transcript = transcriber.transcribe(
+  "https://example.org/audio.mp3",
+  config=aai.TranscriptionConfig(content_safety=True)
+)
+
+
+# Get the parts of the transcript which were flagged as sensitive
+for result in transcript.content_safety_labels.results:
+  print(result.text) # sensitive text snippet
+  print(result.timestamp.start)
+  print(result.timestamp.end)
+
+  for label in result.labels:
+    print(label.label) # content safety category
+    print(label.confidence) # model's confidence that the text is in this category
+    print(label.severity) # severity of the text in relation to the category
+
+# Get the confidence of the most common labels in relation to the entire audio file
+for label, confidence in transcript.content_safety_labels.summary.items():
+  print(f"{confidence * 100}% confident that the audio contains {label}")
+
+# Get the overall severity of the most common labels in relation to the entire audio file
+for label, severity_confidence in transcript.content_safety_labels.severity_score_summary.items():
+  print(f"{severity_confidence.low * 100}% confident that the audio contains low-severity {label}")
+  print(f"{severity_confidence.medium * 100}% confident that the audio contains mid-severity {label}")
+  print(f"{severity_confidence.high * 100}% confident that the audio contains high-severity {label}")
+
+```
+
+[Read more about the content safety categories.](https://www.assemblyai.com/docs/Models/content_moderation#all-labels-supported-by-the-model)
+
+By default, the content safety model will only include labels with a confidence greater than 0.5 (50%). To change this, pass `content_safety_confidence` (as an integer percentage between 25 and 100, inclusive) to the `TranscriptionConfig`:
+
+```python
+config=aai.TranscriptionConfig(
+  content_safety=True,
+  content_safety_confidence=80, # only include labels with a confidence greater than 80%
+)
+```
+
 </details>
 
 ---

assemblyai/api.py

Lines changed: 0 additions & 1 deletion
@@ -34,7 +34,6 @@ def create_transcript(
             by_alias=True,
         ),
     )
-
     if response.status_code != httpx.codes.ok:
         raise types.TranscriptError(
             f"failed to transcript url {request.audio_url}: {_get_error_message(response)}"

assemblyai/transcriber.py

Lines changed: 4 additions & 0 deletions
@@ -214,6 +214,10 @@ def summary(self) -> Optional[str]:
     def chapters(self) -> Optional[List[types.Chapter]]:
         return self._impl.transcript.chapters
 
+    @property
+    def content_safety_labels(self) -> Optional[types.ContentSafetyResponse]:
+        return self._impl.transcript.content_safety_labels
+
     @property
     def status(self) -> types.TranscriptStatus:
         "The current status of the transcript"

assemblyai/types.py

Lines changed: 72 additions & 26 deletions
@@ -339,8 +339,11 @@ class RawTranscriptionConfig(BaseModel):
     speakers_expected: Optional[int]
     "The number of speakers you expect to be in your audio file."
 
-    # content_safety: bool = False
-    # "Enable Content Safety Detection."
+    content_safety: Optional[bool]
+    "Enable Content Safety Detection."
+
+    content_safety_confidence: Optional[int]
+    "The minimum confidence level for a content safety label to be produced."
 
     # iab_categories: bool = False
     # "Enable Topic Detection."

@@ -410,7 +413,8 @@ def __init__(
         redact_pii_sub: Optional[PIISubstitutionPolicy] = None,
         speaker_labels: Optional[bool] = None,
         speakers_expected: Optional[int] = None,
-        # content_safety: bool = False,
+        content_safety: Optional[bool] = None,
+        content_safety_confidence: Optional[int] = None,
         # iab_categories: bool = False,
         custom_spelling: Optional[Dict[str, Union[str, Sequence[str]]]] = None,
         disfluencies: Optional[bool] = None,

@@ -486,7 +490,7 @@ def __init__(
             redact_pii_sub,
         )
         self.set_speaker_diarization(speaker_labels, speakers_expected)
-        # self.content_safety = content_safety
+        self.set_content_safety(content_safety, content_safety_confidence)
         # self.iab_categories = iab_categories
         self.set_custom_spelling(custom_spelling, override=True)
         self.disfluencies = disfluencies

@@ -644,17 +648,51 @@ def speakers_expected(self) -> Optional[int]:
 
         return self._raw_transcription_config.speakers_expected
 
-    # @property
-    # def content_safety(self) -> bool:
-    #     "Returns the status of the Content Safety feature."
+    @property
+    def content_safety(self) -> Optional[bool]:
+        "Returns the status of the Content Safety feature."
+
+        return self._raw_transcription_config.content_safety
+
+    @property
+    def content_safety_confidence(self) -> Optional[int]:
+        "The minimum confidence level for a content safety label to be produced. Used in combination with the `content_safety` parameter."
 
-    #     return self._raw_transcription_config.content_safety
+        return self._raw_transcription_config.content_safety_confidence
 
-    # @content_safety.setter
-    # def content_safety(self, enable: bool) -> None:
-    #     "Enable Content Safety feature."
+    def set_content_safety(
+        self,
+        enable: Optional[bool] = True,
+        content_safety_confidence: Optional[int] = None,
+    ) -> Self:
+        """Enable Content Safety feature.
+
+        Args:
+            `enable`: Whether or not to enable the Content Safety feature.
+            `content_safety_confidence`: The minimum confidence level for a content safety label to be produced.
 
-    #     self._raw_transcription_config.content_safety = enable
+        Raises:
+            `ValueError`: Raised if `content_safety_confidence` is not between 25 and 100 (inclusive).
+        """
+
+        if not enable:
+            self._raw_transcription_config.content_safety = None
+            self._raw_transcription_config.content_safety_confidence = None
+            return self
+
+        if content_safety_confidence is not None and (
+            content_safety_confidence < 25 or content_safety_confidence > 100
+        ):
+            raise ValueError(
+                "content_safety_confidence must be between 25 and 100 (inclusive)."
+            )
+
+        self._raw_transcription_config.content_safety = enable
+        self._raw_transcription_config.content_safety_confidence = (
+            content_safety_confidence
+        )
+
+        return self
 
     # @property
     # def iab_categories(self) -> bool:
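
`set_content_safety` follows the same fluent pattern as the other setters: it validates `content_safety_confidence` against the 25-100 range, writes both values into the raw config, and returns `Self`. A minimal sketch of how the new setter could be used, assuming `TranscriptionConfig` can be constructed without arguments as elsewhere in the SDK:

```python
import assemblyai as aai

# Fluent configuration: the setter mutates the raw config and returns it.
config = aai.TranscriptionConfig().set_content_safety(
    enable=True,
    content_safety_confidence=60,
)

# Out-of-range confidences are rejected before any request is made.
try:
    aai.TranscriptionConfig().set_content_safety(content_safety_confidence=10)
except ValueError as error:
    print(error)  # content_safety_confidence must be between 25 and 100 (inclusive).
```
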
@@ -1162,7 +1200,7 @@ class AutohighlightResponse(BaseModel):
 class ContentSafetyLabelResult(BaseModel):
     label: ContentSafetyLabel
     confidence: float
-    severity: float
+    severity: Optional[float]
 
 
 class ContentSafetySeverityScore(BaseModel):

@@ -1180,8 +1218,10 @@ class ContentSafetyResult(BaseModel):
 class ContentSafetyResponse(BaseModel):
     status: StatusResult
     results: Optional[List[ContentSafetyResult]]
-    summary: Optional[Dict[str, float]]
-    severity_score_summary: Optional[Dict[str, ContentSafetySeverityScore]]
+    summary: Optional[Dict[ContentSafetyLabel, float]]
+    severity_score_summary: Optional[
+        Dict[ContentSafetyLabel, ContentSafetySeverityScore]
+    ]
 
 
 class IABLabelResult(BaseModel):

@@ -1308,8 +1348,11 @@ class BaseTranscript(BaseModel):
     speaker_labels: Optional[bool]
     "Enable Speaker Diarization."
 
-    # content_safety: bool = False
-    # "Enable Content Safety Detection."
+    content_safety: Optional[bool]
+    "Enable Content Safety Detection."
+
+    content_safety_confidence: Optional[int]
+    "The minimum confidence level for a content safety label to be produced."
 
     # iab_categories: bool = False
     # "Enable Topic Detection."

@@ -1401,8 +1444,8 @@ class TranscriptResponse(BaseTranscript):
     # auto_highlights_result: Optional[AutohighlightResponse] = None
     # "The list of results when enabling Automatic Transcript Highlights"
 
-    # content_safety_labels: Optional[ContentSafetyResponse] = None
-    # "The list of results when Content Safety is enabled"
+    content_safety_labels: Optional[ContentSafetyResponse]
+    "The list of results when Content Safety is enabled"
 
     # iab_categories_result: Optional[IABResponse] = None
     # "The list of results when Topic Detection is enabled"

@@ -1416,15 +1459,18 @@ class TranscriptResponse(BaseTranscript):
     # entities: Optional[List[Entity]] = None
     # "When Entity Detection is enabled, the list of detected Entities"
 
-    # def __init__(self, **data: Any):
-    #     # cleanup the response before creating the object
-    #     if data.get("iab_categories_result") == {}:
-    #         data["iab_categories_result"] = None
+    def __init__(self, **data: Any):
+        # cleanup the response before creating the object
+        # if data.get("iab_categories_result") == {}:
+        #     data["iab_categories_result"] = None
 
-    #     if data.get("content_safety_labels") == {}:
-    #         data["content_safety_labels"] = None
+        if data.get("content_safety_labels") == {} or (
+            not data.get("content_safety")
+            and data.get("content_safety_labels", {}).get("status") == "unavailable"
+        ):
+            data["content_safety_labels"] = None
 
-    #     super().__init__(**data)
+        super().__init__(**data)
 
 
 class LemurModel(str, Enum):
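
The `__init__` override normalizes API payloads where content safety was not requested but the response still carries an empty or `"unavailable"` placeholder, so the optional field deserializes as `None`. A standalone sketch of that condition, using a hypothetical `raw` payload dict rather than a full `TranscriptResponse`:

```python
# Hypothetical raw API payload: content safety was not enabled, but the
# response still includes a placeholder content_safety_labels object.
raw = {
    "content_safety": None,
    "content_safety_labels": {"status": "unavailable"},
}

# Same condition as TranscriptResponse.__init__ above.
if raw.get("content_safety_labels") == {} or (
    not raw.get("content_safety")
    and raw.get("content_safety_labels", {}).get("status") == "unavailable"
):
    raw["content_safety_labels"] = None

print(raw["content_safety_labels"])  # None
```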

tests/unit/factories.py

Lines changed: 11 additions & 0 deletions
@@ -3,6 +3,7 @@
 from AssemblyAI's API.
 """
 
+from enum import Enum
 from functools import partial
 from typing import Any, Callable, Dict
 

@@ -13,6 +14,14 @@
 from assemblyai import types
 
 
+class TimestampFactory(factory.Factory):
+    class Meta:
+        model = aai.Timestamp
+
+    start = factory.Faker("pyint")
+    end = factory.Faker("pyint")
+
+
 class WordFactory(factory.Factory):
     class Meta:
         model = aai.Word

@@ -234,6 +243,8 @@ def convert_dict_from_stub(stub: factory.base.StubObject) -> Dict[str, Any]:
                 if stub_is_list(value)
                 else convert_dict_from_stub(value)
             )
+        elif isinstance(value, Enum):
+            stub_dict[key] = value.value
     return stub_dict
 
 def dict_factory(f, **kwargs):
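
The new `Enum` branch in `convert_dict_from_stub` is what lets test stubs carrying enum members (such as `ContentSafetyLabel`) flatten into JSON-compatible values. A standalone sketch of that conversion, using a hypothetical `Label` enum in place of the real SDK types:

```python
from enum import Enum


class Label(str, Enum):
    # Hypothetical stand-in for an enum-valued field like ContentSafetyLabel.
    PROFANITY = "profanity"


stub_dict = {"label": Label.PROFANITY, "confidence": 0.8}

# Mirror of the new branch: replace Enum members with their raw values.
for key, value in stub_dict.items():
    if isinstance(value, Enum):
        stub_dict[key] = value.value

print(stub_dict)  # {'label': 'profanity', 'confidence': 0.8}
```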
