Skip to content

Commit 2357242

Browse files
author
Sam Anklesaria
committed
Aggregate TORCH_LIBRARY calls
1 parent 031c240 commit 2357242

File tree

6 files changed

+16
-26
lines changed

6 files changed

+16
-26
lines changed

src/libtorchaudio/CMakeLists.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,6 @@ if(BUILD_RNNT)
2525
rnnt/cpu/compute_alphas.cpp
2626
rnnt/cpu/compute_betas.cpp
2727
rnnt/cpu/compute.cpp
28-
rnnt/compute_alphas.cpp
29-
rnnt/compute_betas.cpp
3028
rnnt/compute.cpp
3129
)
3230
if (USE_CUDA)

src/libtorchaudio/rnnt/compute.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,18 @@ STABLE_TORCH_LIBRARY(torchaudio, m) {
99
"int blank,"
1010
"float clamp,"
1111
"bool fused_log_softmax) -> (Tensor, Tensor?)");
12+
m.def(
13+
"rnnt_loss_betas(Tensor logits,"
14+
"Tensor targets,"
15+
"Tensor logit_lengths,"
16+
"Tensor target_lengths,"
17+
"int blank,"
18+
"float clamp) -> Tensor");
19+
m.def(
20+
"rnnt_loss_alphas(Tensor logits,"
21+
"Tensor targets,"
22+
"Tensor logit_lengths,"
23+
"Tensor target_lengths,"
24+
"int blank,"
25+
"float clamp) -> Tensor");
1226
}

src/libtorchaudio/rnnt/compute_alphas.cpp

Lines changed: 0 additions & 11 deletions
This file was deleted.

src/libtorchaudio/rnnt/compute_betas.cpp

Lines changed: 0 additions & 11 deletions
This file was deleted.

src/libtorchaudio/rnnt/gpu/compute.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ void boxed_compute(StableIValue* stack, uint64_t num_args, uint64_t num_outputs)
210210
}
211211

212212
STABLE_TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
213-
m.impl("rnnt_loss", &boxed_compute);
213+
m.impl("torchaudio::rnnt_loss", &boxed_compute);
214214
}
215215

216216
} // namespace gpu

src/torchaudio/functional/functional.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1763,7 +1763,7 @@ def _fix_waveform_shape(
17631763
class RnntLoss(torch.autograd.Function):
17641764
@staticmethod
17651765
def forward(ctx, *args):
1766-
output, saved = torch.ops.torchaudio.rnnt_loss_forward(*args)
1766+
output, saved = torch.ops.torchaudio.rnnt_loss(*args)
17671767
ctx.save_for_backward(saved)
17681768
return output
17691769

0 commit comments

Comments
 (0)