diff --git a/README.md b/README.md
index a6a03cc..754eb2a 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,7 @@
---
+
[](https://github.com/AssemblyAI/assemblyai-python-sdk/actions/workflows/test.yml)
[](https://github.com/AssemblyAI/assemblyai-python-sdk/blob/master/LICENSE)
[](https://badge.fury.io/py/assemblyai)
@@ -26,14 +27,12 @@ With a single API call, get access to AI models built on the latest AI breakthro
- [Playgrounds](#playgrounds)
- [Advanced](#advanced-todo)
-
# Documentation
Visit our [AssemblyAI API Documentation](https://www.assemblyai.com/docs) to get an overview of our models!
# Quick Start
-
## Installation
```bash
@@ -66,6 +65,7 @@ transcript = transcriber.transcribe("./my-local-audio-file.wav")
print(transcript.text)
```
+
@@ -79,6 +79,7 @@ transcript = transcriber.transcribe("https://example.org/audio.mp3")
print(transcript.text)
```
+
@@ -96,6 +97,7 @@ print(transcript.export_subtitles_srt())
# in VTT format
print(transcript.export_subtitles_vtt())
```
+
@@ -115,6 +117,7 @@ paragraphs = transcript.get_paragraphs()
for paragraph in paragraphs:
print(paragraph.text)
```
+
@@ -131,6 +134,7 @@ matches = transcript.word_search(["price", "product"])
for match in matches:
print(f"Found '{match.text}' {match.count} times in the transcript")
```
+
@@ -152,9 +156,40 @@ transcript = transcriber.transcribe("https://example.org/audio.mp3", config)
print(transcript.text)
```
+
+
+
+
+ Summarize the content of a transcript
+
+```python
+import assemblyai as aai
+
+transcriber = aai.Transcriber()
+transcript = transcriber.transcribe(
+ "https://example.org/audio.mp3",
+ config=aai.TranscriptionConfig(summarization=True)
+)
+
+print(transcript.summary)
+```
+
+By default, the summarization model will be `informative` and the summarization type will be `bullets`. [Read more about summarization models and types here](https://www.assemblyai.com/docs/Models/summarization#types-and-models).
+
+To change the model and/or type, pass additional parameters to the `TranscriptionConfig`:
+
+```python
+config=aai.TranscriptionConfig(
+ summarization=True,
+ summary_model=aai.SummarizationModel.catchy,
+ summary_type=aai.SummarizationType.headline
+)
+```
+
---
+
### **LeMUR Examples**
@@ -175,6 +210,7 @@ summary = transcript_group.lemur.summarize(context="Customers asking for cars",
print(summary)
```
+
@@ -195,6 +231,7 @@ feedback = transcript_group.lemur.ask_coach(context="Who was the best interviewe
print(feedback)
```
+
@@ -218,6 +255,7 @@ for result in result:
print(f"Question: {result.question}")
print(f"Answer: {result.answer}")
```
+
---
@@ -247,8 +285,8 @@ config.set_pii_redact(
transcriber = aai.Transcriber()
transcript = transcriber.transcribe("https://example.org/audio.mp3", config)
```
-
+
---
diff --git a/assemblyai/transcriber.py b/assemblyai/transcriber.py
index 4d37be4..5a5ad23 100644
--- a/assemblyai/transcriber.py
+++ b/assemblyai/transcriber.py
@@ -204,6 +204,12 @@ def text(self) -> Optional[str]:
return self._impl.transcript.text
+    @property
+    def summary(self) -> Optional[str]:
+        "The `summary` value of the underlying transcript, which may be `None`"
+
+        transcript = self._impl.transcript
+        return transcript.summary
+
@property
def status(self) -> types.TranscriptStatus:
"The current status of the transcript"
diff --git a/assemblyai/types.py b/assemblyai/types.py
index 7f64907..647947f 100644
--- a/assemblyai/types.py
+++ b/assemblyai/types.py
@@ -1004,6 +1004,16 @@ def set_summarize(
return self
+ # Validate that required parameters are also set
+ if self._raw_transcription_config.punctuate == False:
+ raise ValueError(
+ "If `summarization` is enabled, then `punctuate` must not be disabled"
+ )
+ if self._raw_transcription_config.format_text == False:
+ raise ValueError(
+ "If `summarization` is enabled, then `format_text` must not be disabled"
+ )
+
self._raw_transcription_config.summarization = True
self._raw_transcription_config.summary_model = model
self._raw_transcription_config.summary_type = type
@@ -1379,6 +1389,9 @@ class TranscriptResponse(BaseTranscript):
webhook_auth: Optional[bool]
"Whether the webhook was sent with an HTTP authentication header"
+ summary: Optional[str]
+ "The summarization of the transcript"
+
# auto_highlights_result: Optional[AutohighlightResponse] = None
# "The list of results when enabling Automatic Transcript Highlights"
diff --git a/setup.py b/setup.py
index 0edeff5..6cffa7b 100644
--- a/setup.py
+++ b/setup.py
@@ -7,7 +7,7 @@
setup(
name="assemblyai",
- version="0.5.1",
+ version="0.6.0",
description="AssemblyAI Python SDK",
author="AssemblyAI",
author_email="engineering.sdk@assemblyai.com",
diff --git a/tests/unit/factories.py b/tests/unit/factories.py
index c6e200f..d7fbdd8 100644
--- a/tests/unit/factories.py
+++ b/tests/unit/factories.py
@@ -200,7 +200,7 @@ class Meta:
audio_duration = factory.Faker("pyint")
-def generate_dict_factory(f: factory.Factory) -> Callable[[None], Dict[str, Any]]:
+def generate_dict_factory(f: factory.Factory) -> Callable[[], Dict[str, Any]]:
"""
Creates a dict factory from the given *Factory class.
diff --git a/tests/unit/test_summarization.py b/tests/unit/test_summarization.py
new file mode 100644
index 0000000..6713e1b
--- /dev/null
+++ b/tests/unit/test_summarization.py
@@ -0,0 +1,137 @@
+import json
+from typing import Any, Dict
+
+import httpx
+import pytest
+from pytest_httpx import HTTPXMock
+
+import assemblyai as aai
+from tests.unit import factories
+
+aai.settings.api_key = "test"
+
+
+def __submit_request(httpx_mock: HTTPXMock, **params) -> Dict[str, Any]:
+    """
+    Runs one mocked transcription through the SDK, building the
+    `TranscriptionConfig` from the given keyword arguments, and performs
+    the assertions shared by the tests in this module.
+
+    Returns the JSON-decoded body (dictionary) of the initial submission request.
+    """
+    expected_summary = "example summary"
+
+    build_response = factories.generate_dict_factory(
+        factories.TranscriptCompletedResponseFactory
+    )
+    completed_response = build_response()
+    transcript_url = f"{aai.settings.base_url}/transcript"
+
+    # Register the mocked reply to the initial submission (POST)
+    httpx_mock.add_response(
+        url=transcript_url,
+        status_code=httpx.codes.OK,
+        method="POST",
+        json=completed_response,
+    )
+
+    # Register the mocked reply to the polling request (GET); it is the same
+    # payload with the mock summary result added
+    polled_response = {**completed_response, "summary": expected_summary}
+    httpx_mock.add_response(
+        url=f"{transcript_url}/{completed_response['id']}",
+        status_code=httpx.codes.OK,
+        method="GET",
+        json=polled_response,
+    )
+
+    # == Make API request via SDK ==
+    transcriber = aai.Transcriber()
+    config = aai.TranscriptionConfig(**params)
+    transcript = transcriber.transcribe(
+        data="https://example.org/audio.wav",
+        config=config,
+    )
+
+    # Both the submission and the polling request must have been made
+    recorded_requests = httpx_mock.get_requests()
+    assert len(recorded_requests) == 2
+
+    # The summary from the mocked response must surface on the SDK transcript
+    assert transcript.summary == expected_summary
+
+    # The first recorded request is the initial submission; hand back its body
+    submission_request = recorded_requests[0]
+    return json.loads(submission_request.content.decode())
+
+
+@pytest.mark.parametrize("required_field", ["punctuate", "format_text"])
+def test_summarization_fails_without_required_field(
+    httpx_mock: HTTPXMock, required_field: str
+):
+    """
+    Tests whether the SDK raises an error before making a request
+    if `summarization` is enabled and the given required field is disabled
+    """
+    with pytest.raises(ValueError) as error:
+        __submit_request(httpx_mock, summarization=True, **{required_field: False})
+
+    # Check that the error message informs the user of the invalid parameter.
+    # `error` is pytest's ExceptionInfo wrapper; the raised exception (and so
+    # its message) lives on `error.value`, which is what must be inspected.
+    assert required_field in str(error.value)
+
+    # Check that the error was raised before any requests were made
+    assert len(httpx_mock.get_requests()) == 0
+
+    # Inform httpx_mock that it's okay we didn't make any requests
+    httpx_mock.reset(False)
+
+
+def test_summarization_disabled_by_default(httpx_mock: HTTPXMock):
+    """
+    Tests that a `TranscriptionConfig` built without any summarization
+    parameters yields a request body that carries no `summarization` value
+    """
+    submitted_body = __submit_request(httpx_mock)
+    summarization_value = submitted_body.get("summarization")
+    assert summarization_value is None
+
+
+def test_default_summarization_params(httpx_mock: HTTPXMock):
+    """
+    Tests that including `summarization=True` in the `TranscriptionConfig`
+    will result in `summarization=True` in the request body.
+    """
+    request_body = __submit_request(httpx_mock, summarization=True)
+
+    # Identity check: the decoded JSON value must be the boolean `True`,
+    # not merely something that compares equal to it (such as `1`)
+    assert request_body.get("summarization") is True
+
+
+def test_summarization_with_params(httpx_mock: HTTPXMock):
+    """
+    Tests that including additional summarization parameters along with
+    `summarization=True` in the `TranscriptionConfig` will result in all
+    parameters being included in the request as well.
+    """
+
+    summary_model = aai.SummarizationModel.conversational
+    summary_type = aai.SummarizationType.bullets
+
+    request_body = __submit_request(
+        httpx_mock,
+        summarization=True,
+        summary_model=summary_model,
+        summary_type=summary_type,
+    )
+
+    # Identity check: the decoded JSON value must be the boolean `True`,
+    # not merely something that compares equal to it (such as `1`)
+    assert request_body.get("summarization") is True
+
+    # The model and type must arrive in the body exactly as configured
+    assert request_body.get("summary_model") == summary_model
+    assert request_body.get("summary_type") == summary_type
+
+
+def test_summarization_params_excluded_when_disabled(httpx_mock: HTTPXMock):
+    """
+    Tests that when `summarization` is explicitly disabled, neither it nor the
+    additional summarization parameters carry a value in the submission
+    request body.
+    """
+    disabled_params = {
+        "summarization": False,
+        "summary_model": aai.SummarizationModel.conversational,
+        "summary_type": aai.SummarizationType.bullets,
+    }
+    submitted_body = __submit_request(httpx_mock, **disabled_params)
+
+    # Checked in the same order as the parameters were supplied
+    for excluded_key in ("summarization", "summary_model", "summary_type"):
+        assert submitted_body.get(excluded_key) is None