Skip to content

feat: allow passing TranscriptionConfig to Transcriber #5

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 44 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,51 @@ Visit one of our Playgrounds:
- [Transcription Playground](https://www.assemblyai.com/playground)


# Advanced (TODO)
# Advanced

## How the SDK handles Default Configurations

### Defining Defaults

When no `TranscriptionConfig` is passed to the `Transcriber` or to one of its methods, the SDK uses a default `TranscriptionConfig` instance.

If you would like to re-use the same `TranscriptionConfig` for all your transcriptions,
you can set it on the `Transcriber` directly:

```python
config = aai.TranscriptionConfig(punctuate=False, format_text=False)

transcriber = aai.Transcriber(config=config)

# will use the same config for all `.transcribe*(...)` operations
transcriber.transcribe("https://example.org/audio.wav")
```

### Overriding Defaults

You can override the default configuration later via the `.config` property of the `Transcriber`:

```python
transcriber = aai.Transcriber()

# override the `Transcriber`'s config with a new config
transcriber.config = aai.TranscriptionConfig(punctuate=False, format_text=False)
```


To override the `Transcriber`'s configuration for a single operation, pass a different one via the `config` parameter of a `.transcribe*(...)` method:

```python
config = aai.TranscriptionConfig(punctuate=False, format_text=False)
# set a default configuration
transcriber = aai.Transcriber(config=config)

transcriber.transcribe(
"https://example.com/audio.mp3",
# overrides the above configuration on the `Transcriber` with the following
config=aai.TranscriptionConfig(dual_channel=True, disfluencies=True)
)
```

## Synchronous vs Asynchronous

Expand Down
66 changes: 52 additions & 14 deletions assemblyai/transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,8 +442,10 @@ def __init__(
self,
*,
client: _client.Client,
config: types.TranscriptionConfig,
) -> None:
self._client = client
self.config = config

def transcribe_url(
self,
Expand Down Expand Up @@ -517,7 +519,7 @@ def transcribe(
poll: bool,
) -> Transcript:
if config is None:
config = types.TranscriptionConfig()
config = self.config

if urlparse(data).scheme in {"http", "https"}:
return self.transcribe_url(
Expand All @@ -536,9 +538,12 @@ def transcribe_group(
self,
*,
data: List[str],
config: types.TranscriptionConfig,
config: Optional[types.TranscriptionConfig],
poll: bool,
) -> TranscriptGroup:
if config is None:
config = self.config

executor = concurrent.futures.ThreadPoolExecutor(max_workers=8)
future_transcripts: Dict[concurrent.futures.Future[Transcript], str] = {}

Expand Down Expand Up @@ -576,6 +581,7 @@ def __init__(
self,
*,
client: Optional[_client.Client] = None,
config: Optional[types.TranscriptionConfig] = None,
max_workers: Optional[int] = None,
) -> None:
"""
Expand All @@ -584,13 +590,29 @@ def __init__(
Args:
`client`: The `Client` to use for the `Transcriber`. If `None` is given, the
default settings for the `Client` will be used.
`config`: The default configuration for the `Transcriber`. If `None` is given,
the default configuration of a `TranscriptionConfig` will be used.
`max_workers`: The maximum number of parallel jobs when using the `_async`
methods on the `Transcriber`. By default it uses `os.cpu_count() - 1`

Example:
To use the `Transcriber` with the default settings, you can simply do:
```
transcriber = aai.Transcriber()
```

To use the `Transcriber` with a custom configuration, you can do:
```
config = aai.TranscriptionConfig(punctuate=False, format_text=False)

transcriber = aai.Transcriber(config=config)
```
"""
self._client = client or _client.Client.get_default()

self._impl = _TranscriberImpl(
client=self._client,
config=config or types.TranscriptionConfig(),
)

if not max_workers:
Expand All @@ -600,6 +622,23 @@ def __init__(
max_workers=max_workers,
)

@property
def config(self) -> types.TranscriptionConfig:
"""
Returns the default configuration of the `Transcriber`.

The value is read from the underlying implementation object (`self._impl`),
which falls back to it when a transcription call receives `config=None`.
"""
return self._impl.config

@config.setter
def config(self, config: types.TranscriptionConfig) -> None:
"""
Sets the default configuration of the `Transcriber`.

The new configuration is stored on the underlying implementation object
(`self._impl`), which uses it for calls that are given `config=None`.

Args:
`config`: The new default configuration.
"""
self._impl.config = config

def submit(
self,
data: str,
Expand All @@ -610,7 +649,8 @@ def submit(

Args:
data: An URL or a local file (as path)
config: Transcription options and features.
config: Transcription options and features. If `None` is given, the Transcriber's
default configuration will be used.
"""
return self._impl.transcribe(
data=data,
Expand All @@ -628,8 +668,8 @@ def transcribe(

Args:
data: An URL or a local file (as path)
config: Transcription options and features.
poll: Whether the transcript should be polled for its completion.
config: Transcription options and features. If `None` is given, the Transcriber's
default configuration will be used.
"""

return self._impl.transcribe(
Expand All @@ -648,8 +688,8 @@ def transcribe_async(

Args:
data: An URL or a local file (as path)
config: Transcription options and features.
poll: Whether the transcript should be polled for its completion.
config: Transcription options and features. If `None` is given, the Transcriber's
default configuration will be used.
"""

return self._executor.submit(
Expand All @@ -669,11 +709,9 @@ def transcribe_group(

Args:
data: A list of paths or URLs (can be mixed)
config: Transcription options and features.
poll: Whether the transcripts should be polled for their completion.
config: Transcription options and features. If `None` is given, the Transcriber's
default configuration will be used.
"""
if config is None:
config = types.TranscriptionConfig()

return self._impl.transcribe_group(
data=data,
Expand All @@ -683,7 +721,7 @@ def transcribe_group(

def transcribe_group_async(
self,
data: str,
data: List[str],
config: Optional[types.TranscriptionConfig] = None,
) -> concurrent.futures.Future[TranscriptGroup]:
"""
Expand All @@ -692,8 +730,8 @@ def transcribe_group_async(

Args:
data: A list of paths or URLs (can be mixed)
config: Transcription options and features.
poll: Whether the transcripts should be polled for their completion.
config: Transcription options and features. If `None` is given, the Transcriber's
default configuration will be used.
"""

return self._executor.submit(
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

setup(
name="assemblyai",
version="0.3.3",
version="0.4.0",
description="AssemblyAI Python SDK",
author="AssemblyAI",
author_email="[email protected]",
Expand Down