From 77d44213eaed6440f13f187852fbd28172982d51 Mon Sep 17 00:00:00 2001
From: Sam Anklesaria
Date: Fri, 11 Jul 2025 16:01:53 +0000
Subject: [PATCH 1/8] Port autograd code for rnnt

---
 src/libtorchaudio/rnnt/autograd.cpp     | 51 ++-----------------------
 src/torchaudio/functional/functional.py | 17 ++++++++-
 2 files changed, 19 insertions(+), 49 deletions(-)

diff --git a/src/libtorchaudio/rnnt/autograd.cpp b/src/libtorchaudio/rnnt/autograd.cpp
index dcf68409ed..5ba545cb99 100644
--- a/src/libtorchaudio/rnnt/autograd.cpp
+++ b/src/libtorchaudio/rnnt/autograd.cpp
@@ -3,31 +3,7 @@
 namespace torchaudio {
 namespace rnnt {
 
-class RNNTLossFunction : public torch::autograd::Function {
- public:
-  static torch::autograd::tensor_list forward(
-      torch::autograd::AutogradContext* ctx,
-      torch::Tensor& logits,
-      const torch::Tensor& targets,
-      const torch::Tensor& logit_lengths,
-      const torch::Tensor& target_lengths,
-      int64_t blank,
-      double clamp,
-      bool fused_log_softmax = true) {
-    torch::Tensor undef;
-    auto result = rnnt_loss(
-        logits,
-        targets,
-        logit_lengths,
-        target_lengths,
-        blank,
-        clamp,
-        fused_log_softmax);
-    auto costs = std::get<0>(result);
-    auto grads = std::get<1>(result).value_or(undef);
-    ctx->save_for_backward({grads});
-    return {costs, grads};
-  }
+
 
   static torch::autograd::tensor_list backward(
       torch::autograd::AutogradContext* ctx,
@@ -39,31 +15,10 @@ class RNNTLossFunction : public torch::autograd::Function {
     torch::Tensor undef;
     return {result, undef, undef, undef, undef, undef, undef, undef};
   }
-};
-
-std::tuple> rnnt_loss_autograd(
-    torch::Tensor& logits,
-    const torch::Tensor& targets,
-    const torch::Tensor& logit_lengths,
-    const torch::Tensor& target_lengths,
-    int64_t blank,
-    double clamp,
-    bool fused_log_softmax = true) {
-  at::AutoDispatchBelowADInplaceOrView guard;
-  auto results = RNNTLossFunction::apply(
-      logits,
-      targets,
-      logit_lengths,
-      target_lengths,
-      blank,
-      clamp,
-      fused_log_softmax);
-  return std::make_tuple(results[0], results[1]);
 }
 
-TORCH_LIBRARY_IMPL(torchaudio, Autograd, m) {
-  m.impl("rnnt_loss", rnnt_loss_autograd);
+TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
+  m.def("torchaudio::rnnt_loss_forward", &rnnt_loss);
 }
 
-} // namespace rnnt
 } // namespace torchaudio
diff --git a/src/torchaudio/functional/functional.py b/src/torchaudio/functional/functional.py
index 42dde06814..e25194dbd5 100644
--- a/src/torchaudio/functional/functional.py
+++ b/src/torchaudio/functional/functional.py
@@ -1760,6 +1760,21 @@ def _fix_waveform_shape(
     waveform_shift = waveform_shift.view(shape[:-1] + waveform_shift.shape[-1:])
     return waveform_shift
 
+class RnntLoss(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, *args):
+        output, saved = torch.ops.torchaudio.rnnt_loss_forward(*args)
+        ctx.save_for_backward(saved)
+        return output
+
+    @staticmethod
+    def backward(ctx, dy):
+        grad = ctx.saved_tensors[0]
+        grad_out = dy.view((-1, 1, 1, 1))
+        result = grad * grad_out;
+        return (result, None, None, None, None, None, None, None)
+
+torch.ops.torchaudio.rnnt_loss_forward
 
 def _rnnt_loss(
     logits: Tensor,
@@ -1803,7 +1818,7 @@ def _rnnt_loss(
     if blank < 0:  # reinterpret blank index if blank < 0.
         blank = logits.shape[-1] + blank
 
-    costs, _ = torch.ops.torchaudio.rnnt_loss(
+    costs = RnntLoss.apply(
         logits=logits,
         targets=targets,
         logit_lengths=logit_lengths,

From 725c74e9c579eb5d14b9c2f58375d7a3acf299c7 Mon Sep 17 00:00:00 2001
From: Sam Anklesaria
Date: Fri, 11 Jul 2025 16:04:34 +0000
Subject: [PATCH 2/8] Correct rnnt calling arguments

---
 src/torchaudio/functional/functional.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/src/torchaudio/functional/functional.py b/src/torchaudio/functional/functional.py
index e25194dbd5..8abd075546 100644
--- a/src/torchaudio/functional/functional.py
+++ b/src/torchaudio/functional/functional.py
@@ -1819,13 +1819,13 @@ def _rnnt_loss(
         blank = logits.shape[-1] + blank
 
     costs = RnntLoss.apply(
-        logits=logits,
-        targets=targets,
-        logit_lengths=logit_lengths,
-        target_lengths=target_lengths,
-        blank=blank,
-        clamp=clamp,
-        fused_log_softmax=fused_log_softmax,
+        logits,
+        targets,
+        logit_lengths,
+        target_lengths,
+        blank,
+        clamp,
+        fused_log_softmax
     )
 
     if reduction == "mean":

From 97176519c5935cde5558a6af32d268fa28637ea1 Mon Sep 17 00:00:00 2001
From: Sam Anklesaria
Date: Fri, 11 Jul 2025 16:50:05 +0000
Subject: [PATCH 3/8] Disable torchscript checks

---
 .github/scripts/unittest-linux/run_test.sh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/scripts/unittest-linux/run_test.sh b/.github/scripts/unittest-linux/run_test.sh
index f311c8370e..dacde20bea 100755
--- a/.github/scripts/unittest-linux/run_test.sh
+++ b/.github/scripts/unittest-linux/run_test.sh
@@ -30,5 +30,5 @@ fi
 
 (
   cd test
-  pytest torchaudio_unittest -k "not backend and not /io/ and not prototype and not sox and not ffmpeg and not fairseq and not hdemucs"
+  pytest torchaudio_unittest -k "not backend and not /io/ and not prototype and not sox and not ffmpeg and not fairseq and not hdemucs and not torchscript"
 )

From 2b882503158935fd99fb62dea63e387a2d8d3534 Mon Sep 17 00:00:00 2001
From: Sam Anklesaria
Date: Fri, 11 Jul 2025 17:18:57 +0000
Subject: [PATCH 4/8] Restrict disabling of torchscript tests

---
 .github/scripts/unittest-linux/run_test.sh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/scripts/unittest-linux/run_test.sh b/.github/scripts/unittest-linux/run_test.sh
index dacde20bea..559b55437a 100755
--- a/.github/scripts/unittest-linux/run_test.sh
+++ b/.github/scripts/unittest-linux/run_test.sh
@@ -30,5 +30,5 @@ fi
 
 (
   cd test
-  pytest torchaudio_unittest -k "not backend and not /io/ and not prototype and not sox and not ffmpeg and not fairseq and not hdemucs and not torchscript"
+  pytest torchaudio_unittest -k "not backend and not /io/ and not prototype and not sox and not ffmpeg and not fairseq and not hdemucs and not (torchscript and rnnt)"
 )

From 116de6f9a2778602b0f0462d3fdf67c6784d97ff Mon Sep 17 00:00:00 2001
From: Sam Anklesaria
Date: Fri, 11 Jul 2025 17:30:34 +0000
Subject: [PATCH 5/8] Remove leftover line

---
 src/torchaudio/functional/functional.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/torchaudio/functional/functional.py b/src/torchaudio/functional/functional.py
index 8abd075546..f955fe7840 100644
--- a/src/torchaudio/functional/functional.py
+++ b/src/torchaudio/functional/functional.py
@@ -1774,8 +1774,6 @@ def backward(ctx, dy):
         result = grad * grad_out;
         return (result, None, None, None, None, None, None, None)
 
-torch.ops.torchaudio.rnnt_loss_forward
-
 def _rnnt_loss(
     logits: Tensor,
     targets: Tensor,

From 003b3a9d810c1cf926627b201efa9e17a3fb1838 Mon Sep 17 00:00:00 2001
From: Sam Anklesaria
Date: Mon, 14 Jul 2025 14:36:49 +0000
Subject: [PATCH 6/8] Remove unnecessary backward code

---
 src/libtorchaudio/rnnt/autograd.cpp | 14 --------------
 1 file changed, 14 deletions(-)

diff --git a/src/libtorchaudio/rnnt/autograd.cpp b/src/libtorchaudio/rnnt/autograd.cpp
index 5ba545cb99..05b767194d 100644
--- a/src/libtorchaudio/rnnt/autograd.cpp
+++ b/src/libtorchaudio/rnnt/autograd.cpp
@@ -1,22 +1,8 @@
 #include
 
 namespace torchaudio {
-namespace rnnt {
-
 
 
-  static torch::autograd::tensor_list backward(
-      torch::autograd::AutogradContext* ctx,
-      torch::autograd::tensor_list grad_outputs) {
-    auto saved = ctx->get_saved_variables();
-    auto grad = saved[0];
-    auto grad_out = grad_outputs[0].view({-1, 1, 1, 1});
-    auto result = grad * grad_out;
-    torch::Tensor undef;
-    return {result, undef, undef, undef, undef, undef, undef, undef};
-  }
-}
-
 TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
   m.def("torchaudio::rnnt_loss_forward", &rnnt_loss);
 }

From 7727ad773ed9f01e16e34164d2c4a742f77cdd6c Mon Sep 17 00:00:00 2001
From: Sam Anklesaria
Date: Mon, 14 Jul 2025 14:50:59 +0000
Subject: [PATCH 7/8] Move rnnt_loss_forward to compute.cpp

---
 src/libtorchaudio/rnnt/autograd.cpp | 10 ----------
 src/libtorchaudio/rnnt/compute.cpp  |  1 +
 2 files changed, 1 insertion(+), 10 deletions(-)
 delete mode 100644 src/libtorchaudio/rnnt/autograd.cpp

diff --git a/src/libtorchaudio/rnnt/autograd.cpp b/src/libtorchaudio/rnnt/autograd.cpp
deleted file mode 100644
index 05b767194d..0000000000
--- a/src/libtorchaudio/rnnt/autograd.cpp
+++ /dev/null
@@ -1,10 +0,0 @@
-#include
-
-namespace torchaudio {
-
-
-TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
-  m.def("torchaudio::rnnt_loss_forward", &rnnt_loss);
-}
-
-} // namespace torchaudio
diff --git a/src/libtorchaudio/rnnt/compute.cpp b/src/libtorchaudio/rnnt/compute.cpp
index 567c9b5d4b..5aba334cee 100644
--- a/src/libtorchaudio/rnnt/compute.cpp
+++ b/src/libtorchaudio/rnnt/compute.cpp
@@ -30,4 +30,5 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
       "int blank,"
      "float clamp,"
      "bool fused_log_softmax) -> (Tensor, Tensor?)");
+  m.def("torchaudio::rnnt_loss_forward", &rnnt_loss);
 }

From 9b9dc2573f8736bed0c86c5a8ee271cbd11cfc1d Mon Sep 17 00:00:00 2001
From: Sam Anklesaria
Date: Mon, 14 Jul 2025 16:09:07 +0000
Subject: [PATCH 8/8] Remove autograd rnnt in cmakelists

---
 src/libtorchaudio/CMakeLists.txt | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/libtorchaudio/CMakeLists.txt b/src/libtorchaudio/CMakeLists.txt
index 713cb50533..85bc227cd6 100644
--- a/src/libtorchaudio/CMakeLists.txt
+++ b/src/libtorchaudio/CMakeLists.txt
@@ -28,7 +28,6 @@ if(BUILD_RNNT)
     rnnt/compute_alphas.cpp
     rnnt/compute_betas.cpp
     rnnt/compute.cpp
-    rnnt/autograd.cpp
   )
   if (USE_CUDA)
     list(
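Note on the resulting gradient path: after this series, the RNN-T loss backward pass runs through the Python-side RnntLoss autograd function added to functional.py (which wraps the new torchaudio::rnnt_loss_forward op) rather than the removed C++ Autograd dispatch registration. Below is a minimal sketch of exercising that path through the public torchaudio.functional.rnnt_loss API; the sizes, blank index, and random inputs are illustrative assumptions and are not part of the patches.

    import torch
    import torchaudio.functional as F

    # Illustrative sizes only: batch of 2, 10 encoder frames, target length 3
    # (the U dimension of logits is target length + 1), 5 output labels.
    B, T, U, D = 2, 10, 4, 5
    logits = torch.randn(B, T, U, D, requires_grad=True)
    targets = torch.randint(1, D, (B, U - 1), dtype=torch.int32)   # labels exclude blank=0
    logit_lengths = torch.full((B,), T, dtype=torch.int32)
    target_lengths = torch.full((B,), U - 1, dtype=torch.int32)

    loss = F.rnnt_loss(logits, targets, logit_lengths, target_lengths, blank=0)
    loss.backward()            # gradient flows through RnntLoss.backward in functional.py
    print(logits.grad.shape)   # torch.Size([2, 10, 4, 5])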