feat: add summarization functionality #12

Merged
merged 1 commit into from
Jun 2, 2023
44 changes: 41 additions & 3 deletions README.md
@@ -1,6 +1,7 @@
<img src="https://github.com/AssemblyAI/assemblyai-python-sdk/blob/master/assemblyai.png?raw=true" width="500"/>

---

[![CI Passing](https://github.com/AssemblyAI/assemblyai-python-sdk/actions/workflows/test.yml/badge.svg)](https://github.com/AssemblyAI/assemblyai-python-sdk/actions/workflows/test.yml)
[![GitHub License](https://img.shields.io/github/license/AssemblyAI/assemblyai-python-sdk)](https://github.com/AssemblyAI/assemblyai-python-sdk/blob/master/LICENSE)
[![PyPI version](https://badge.fury.io/py/assemblyai.svg)](https://badge.fury.io/py/assemblyai)
@@ -26,14 +27,12 @@ With a single API call, get access to AI models built on the latest AI breakthro
- [Playgrounds](#playgrounds)
- [Advanced](#advanced-todo)


# Documentation

Visit our [AssemblyAI API Documentation](https://www.assemblyai.com/docs) to get an overview of our models!

# Quick Start


## Installation

```bash
@@ -66,6 +65,7 @@ transcript = transcriber.transcribe("./my-local-audio-file.wav")

print(transcript.text)
```

</details>

<details>
@@ -79,6 +79,7 @@ transcript = transcriber.transcribe("https://example.org/audio.mp3")

print(transcript.text)
```

</details>

<details>
@@ -96,6 +97,7 @@ print(transcript.export_subtitles_srt())
# in VTT format
print(transcript.export_subtitles_vtt())
```

</details>

<details>
@@ -115,6 +117,7 @@ paragraphs = transcript.get_paragraphs()
for paragraph in paragraphs:
    print(paragraph.text)
```

</details>

<details>
@@ -131,6 +134,7 @@ matches = transcript.word_search(["price", "product"])
for match in matches:
    print(f"Found '{match.text}' {match.count} times in the transcript")
```

</details>

<details>
@@ -152,9 +156,40 @@ transcript = transcriber.transcribe("https://example.org/audio.mp3", config)

print(transcript.text)
```

</details>

<details>
<summary>Summarize the content of a transcript</summary>

```python
import assemblyai as aai

transcriber = aai.Transcriber()
transcript = transcriber.transcribe(
    "https://example.org/audio.mp3",
    config=aai.TranscriptionConfig(summarization=True)
)

print(transcript.summary)
```

By default, the summarization model will be `informative` and the summarization type will be `bullets`. [Read more about summarization models and types here](https://www.assemblyai.com/docs/Models/summarization#types-and-models).

To change the model and/or type, pass additional parameters to the `TranscriptionConfig`:

```python
config=aai.TranscriptionConfig(
    summarization=True,
    summary_model=aai.SummarizationModel.catchy,
    summary_type=aai.SummarizationType.headline
)
```

</details>

---

### **LeMUR Examples**

<details>
@@ -175,6 +210,7 @@ summary = transcript_group.lemur.summarize(context="Customers asking for cars",

print(summary)
```

</details>

<details>
@@ -195,6 +231,7 @@ feedback = transcript_group.lemur.ask_coach(context="Who was the best interviewe

print(feedback)
```

</details>

<details>
@@ -218,6 +255,7 @@ for result in result:
    print(f"Question: {result.question}")
    print(f"Answer: {result.answer}")
```

</details>

---
@@ -247,8 +285,8 @@ config.set_pii_redact(
transcriber = aai.Transcriber()
transcript = transcriber.transcribe("https://example.org/audio.mp3", config)
```
</details>

</details>

---

6 changes: 6 additions & 0 deletions assemblyai/transcriber.py
@@ -204,6 +204,12 @@ def text(self) -> Optional[str]:

        return self._impl.transcript.text

    @property
    def summary(self) -> Optional[str]:
        "The summarization of the transcript"

        return self._impl.transcript.summary

    @property
    def status(self) -> types.TranscriptStatus:
        "The current status of the transcript"
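The new `summary` property is typed `Optional[str]`, so callers may see `None` (for example, when summarization was not requested). The sketch below illustrates the same delegation pattern with stand-in classes; `TranscriptResponseSketch`, `TranscriptSketch`, and `describe` are hypothetical names used only for illustration and are not part of the SDK.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class TranscriptResponseSketch:
    """Hypothetical stand-in for the API response model."""

    text: Optional[str] = None
    summary: Optional[str] = None


class TranscriptSketch:
    """Hypothetical wrapper exposing response fields as read-only properties."""

    def __init__(self, response: TranscriptResponseSketch) -> None:
        self._response = response

    @property
    def summary(self) -> Optional[str]:
        "The summarization of the transcript, or None if none was returned"
        return self._response.summary


def describe(transcript: TranscriptSketch) -> str:
    # Handle the None case explicitly rather than assuming a string.
    if transcript.summary is None:
        return "no summary available"
    return f"summary: {transcript.summary}"
```

Checking for `None` before use keeps calling code safe when the response carries no summary.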
13 changes: 13 additions & 0 deletions assemblyai/types.py
@@ -1004,6 +1004,16 @@ def set_summarize(

            return self

        # Validate that required parameters are also set
        if self._raw_transcription_config.punctuate == False:
            raise ValueError(
                "If `summarization` is enabled, then `punctuate` must not be disabled"
            )
        if self._raw_transcription_config.format_text == False:
            raise ValueError(
                "If `summarization` is enabled, then `format_text` must not be disabled"
            )

        self._raw_transcription_config.summarization = True
        self._raw_transcription_config.summary_model = model
        self._raw_transcription_config.summary_type = type
@@ -1379,6 +1389,9 @@ class TranscriptResponse(BaseTranscript):
    webhook_auth: Optional[bool]
    "Whether the webhook was sent with an HTTP authentication header"

    summary: Optional[str]
    "The summarization of the transcript"

    # auto_highlights_result: Optional[AutohighlightResponse] = None
    # "The list of results when enabling Automatic Transcript Highlights"

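The validation added to `set_summarize` in `types.py` rejects only an explicit `False` for `punctuate` or `format_text`; an unset value (`None`) passes. A minimal, self-contained sketch of that behaviour follows. `ConfigSketch` is a hypothetical, simplified class written for illustration, not the SDK's `TranscriptionConfig`.

```python
from typing import Optional


class ConfigSketch:
    """Hypothetical, simplified stand-in for a transcription config."""

    def __init__(
        self,
        punctuate: Optional[bool] = None,
        format_text: Optional[bool] = None,
    ) -> None:
        self.punctuate = punctuate
        self.format_text = format_text
        self.summarization: Optional[bool] = None

    def set_summarize(self, enable: bool = True) -> "ConfigSketch":
        if not enable:
            # Disabling summarization clears the flag and skips validation.
            self.summarization = None
            return self

        # Only an explicit False is rejected; None (unset) is allowed through.
        for name in ("punctuate", "format_text"):
            if getattr(self, name) is False:
                raise ValueError(
                    f"If `summarization` is enabled, then `{name}` must not be disabled"
                )

        self.summarization = True
        return self
```

Raising before any request is built means a conflicting configuration fails locally instead of being sent to the API.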
2 changes: 1 addition & 1 deletion setup.py
@@ -7,7 +7,7 @@

setup(
    name="assemblyai",
-    version="0.5.1",
+    version="0.6.0",
    description="AssemblyAI Python SDK",
    author="AssemblyAI",
    author_email="[email protected]",
2 changes: 1 addition & 1 deletion tests/unit/factories.py
@@ -200,7 +200,7 @@ class Meta:
    audio_duration = factory.Faker("pyint")


-def generate_dict_factory(f: factory.Factory) -> Callable[[None], Dict[str, Any]]:
+def generate_dict_factory(f: factory.Factory) -> Callable[[], Dict[str, Any]]:
    """
    Creates a dict factory from the given *Factory class.

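The signature change in `factories.py` corrects the return annotation: in `typing`, `Callable[[], T]` describes a callable that takes no arguments, whereas `Callable[[None], T]` describes one that takes a single argument of type `None`. A small sketch of a zero-argument dict factory with the corrected annotation is shown here; `make_dict_factory` is a hypothetical helper for illustration, not the function from the test suite.

```python
from typing import Any, Callable, Dict


def make_dict_factory(defaults: Dict[str, Any]) -> Callable[[], Dict[str, Any]]:
    """Hypothetical helper: returns a zero-argument callable producing a fresh dict."""

    def build() -> Dict[str, Any]:
        # Copy so that each call yields an independent dictionary.
        return dict(defaults)

    return build


factory = make_dict_factory({"id": "abc", "status": "completed"})
first, second = factory(), factory()
```

Each call to `factory()` takes no arguments, which matches the `Callable[[], ...]` annotation.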
137 changes: 137 additions & 0 deletions tests/unit/test_summarization.py
@@ -0,0 +1,137 @@
import json
from typing import Any, Dict

import httpx
import pytest
from pytest_httpx import HTTPXMock

import assemblyai as aai
from tests.unit import factories

aai.settings.api_key = "test"


def __submit_request(httpx_mock: HTTPXMock, **params) -> Dict[str, Any]:
    """
    Helper function to abstract calling transcriber with given parameters,
    and perform some common assertions.

    Returns the body (dictionary) of the initial submission request.
    """
    summary = "example summary"

    mock_transcript_response = factories.generate_dict_factory(
        factories.TranscriptCompletedResponseFactory
    )()

    # Mock initial submission response
    httpx_mock.add_response(
        url=f"{aai.settings.base_url}/transcript",
        status_code=httpx.codes.OK,
        method="POST",
        json=mock_transcript_response,
    )

    # Mock polling-for-completeness response, with mock summary result
    httpx_mock.add_response(
        url=f"{aai.settings.base_url}/transcript/{mock_transcript_response['id']}",
        status_code=httpx.codes.OK,
        method="GET",
        json={**mock_transcript_response, "summary": summary},
    )

    # == Make API request via SDK ==
    transcript = aai.Transcriber().transcribe(
        data="https://example.org/audio.wav",
        config=aai.TranscriptionConfig(
            **params,
        ),
    )

    # Check that submission and polling requests were made
    assert len(httpx_mock.get_requests()) == 2

    # Check that summary field from response was traced back through SDK classes
    assert transcript.summary == summary

    # Extract and return body of initial submission request
    request = httpx_mock.get_requests()[0]
    return json.loads(request.content.decode())


@pytest.mark.parametrize("required_field", ["punctuate", "format_text"])
def test_summarization_fails_without_required_field(
    httpx_mock: HTTPXMock, required_field: str
):
    """
    Tests whether the SDK raises an error before making a request
    if `summarization` is enabled and the given required field is disabled
    """
    with pytest.raises(ValueError) as error:
        __submit_request(httpx_mock, summarization=True, **{required_field: False})

    # Check that the error message informs the user of the invalid parameter
    # (inspect the raised exception itself, not the ExceptionInfo wrapper)
    assert required_field in str(error.value)

    # Check that the error was raised before any requests were made
    assert len(httpx_mock.get_requests()) == 0

    # Inform httpx_mock that it's okay we didn't make any requests
    httpx_mock.reset(False)


def test_summarization_disabled_by_default(httpx_mock: HTTPXMock):
    """
    Tests that excluding `summarization` from the `TranscriptionConfig` will
    result in the default behavior of it being excluded from the request body
    """
    request_body = __submit_request(httpx_mock)
    assert request_body.get("summarization") is None


def test_default_summarization_params(httpx_mock: HTTPXMock):
    """
    Tests that including `summarization=True` in the `TranscriptionConfig`
    will result in `summarization=True` in the request body.
    """
    request_body = __submit_request(httpx_mock, summarization=True)
    assert request_body.get("summarization") == True


def test_summarization_with_params(httpx_mock: HTTPXMock):
    """
    Tests that including additional summarization parameters along with
    `summarization=True` in the `TranscriptionConfig` will result in all
    parameters being included in the request as well.
    """

    summary_model = aai.SummarizationModel.conversational
    summary_type = aai.SummarizationType.bullets

    request_body = __submit_request(
        httpx_mock,
        summarization=True,
        summary_model=summary_model,
        summary_type=summary_type,
    )

    assert request_body.get("summarization") == True
    assert request_body.get("summary_model") == summary_model
    assert request_body.get("summary_type") == summary_type


def test_summarization_params_excluded_when_disabled(httpx_mock: HTTPXMock):
    """
    Tests that additional summarization parameters are excluded from the submission
    request body if `summarization` itself is not enabled.
    """
    request_body = __submit_request(
        httpx_mock,
        summarization=False,
        summary_model=aai.SummarizationModel.conversational,
        summary_type=aai.SummarizationType.bullets,
    )

    assert request_body.get("summarization") is None
    assert request_body.get("summary_model") is None
    assert request_body.get("summary_type") is None
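
The tests above pin down how the submission body should look: summary options appear only when summarization is enabled, and are dropped otherwise. The following sketch expresses that rule as a standalone function. `build_request_body` and the `audio_url` key are assumptions made for illustration; the SDK builds its request body internally and may do so differently.

```python
from typing import Any, Dict, Optional


def build_request_body(
    audio_url: str,
    summarization: Optional[bool] = None,
    summary_model: Optional[str] = None,
    summary_type: Optional[str] = None,
) -> Dict[str, Any]:
    """Hypothetical builder mirroring the behaviour the tests assert."""
    body: Dict[str, Any] = {"audio_url": audio_url}

    # Summary options are only meaningful when summarization is enabled,
    # so they are omitted entirely otherwise.
    if summarization:
        body["summarization"] = True
        if summary_model is not None:
            body["summary_model"] = summary_model
        if summary_type is not None:
            body["summary_type"] = summary_type

    return body
```

Omitting the keys, rather than sending `null` or `False`, is what allows the tests to check `request_body.get(...) is None`.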