diff --git a/.github/scripts/unittest-linux/run_test.sh b/.github/scripts/unittest-linux/run_test.sh
index f311c8370e..559b55437a 100755
--- a/.github/scripts/unittest-linux/run_test.sh
+++ b/.github/scripts/unittest-linux/run_test.sh
@@ -30,5 +30,5 @@ fi

 (
   cd test
-  pytest torchaudio_unittest -k "not backend and not /io/ and not prototype and not sox and not ffmpeg and not fairseq and not hdemucs"
+  pytest torchaudio_unittest -k "not backend and not /io/ and not prototype and not sox and not ffmpeg and not fairseq and not hdemucs and not (torchscript and rnnt)"
 )
diff --git a/src/libtorchaudio/CMakeLists.txt b/src/libtorchaudio/CMakeLists.txt
index 713cb50533..85bc227cd6 100644
--- a/src/libtorchaudio/CMakeLists.txt
+++ b/src/libtorchaudio/CMakeLists.txt
@@ -28,7 +28,6 @@ if(BUILD_RNNT)
     rnnt/compute_alphas.cpp
     rnnt/compute_betas.cpp
     rnnt/compute.cpp
-    rnnt/autograd.cpp
     )
   if (USE_CUDA)
     list(
diff --git a/src/libtorchaudio/rnnt/autograd.cpp b/src/libtorchaudio/rnnt/autograd.cpp
deleted file mode 100644
index dcf68409ed..0000000000
--- a/src/libtorchaudio/rnnt/autograd.cpp
+++ /dev/null
@@ -1,69 +0,0 @@
-#include <libtorchaudio/rnnt/compute.h>
-
-namespace torchaudio {
-namespace rnnt {
-
-class RNNTLossFunction : public torch::autograd::Function<RNNTLossFunction> {
- public:
-  static torch::autograd::tensor_list forward(
-      torch::autograd::AutogradContext* ctx,
-      torch::Tensor& logits,
-      const torch::Tensor& targets,
-      const torch::Tensor& logit_lengths,
-      const torch::Tensor& target_lengths,
-      int64_t blank,
-      double clamp,
-      bool fused_log_softmax = true) {
-    torch::Tensor undef;
-    auto result = rnnt_loss(
-        logits,
-        targets,
-        logit_lengths,
-        target_lengths,
-        blank,
-        clamp,
-        fused_log_softmax);
-    auto costs = std::get<0>(result);
-    auto grads = std::get<1>(result).value_or(undef);
-    ctx->save_for_backward({grads});
-    return {costs, grads};
-  }
-
-  static torch::autograd::tensor_list backward(
-      torch::autograd::AutogradContext* ctx,
-      torch::autograd::tensor_list grad_outputs) {
-    auto saved = ctx->get_saved_variables();
-    auto grad = saved[0];
-    auto grad_out = grad_outputs[0].view({-1, 1, 1, 1});
-    auto result = grad * grad_out;
-    torch::Tensor undef;
-    return {result, undef, undef, undef, undef, undef, undef, undef};
-  }
-};
-
-std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss_autograd(
-    torch::Tensor& logits,
-    const torch::Tensor& targets,
-    const torch::Tensor& logit_lengths,
-    const torch::Tensor& target_lengths,
-    int64_t blank,
-    double clamp,
-    bool fused_log_softmax = true) {
-  at::AutoDispatchBelowADInplaceOrView guard;
-  auto results = RNNTLossFunction::apply(
-      logits,
-      targets,
-      logit_lengths,
-      target_lengths,
-      blank,
-      clamp,
-      fused_log_softmax);
-  return std::make_tuple(results[0], results[1]);
-}
-
-TORCH_LIBRARY_IMPL(torchaudio, Autograd, m) {
-  m.impl("rnnt_loss", rnnt_loss_autograd);
-}
-
-} // namespace rnnt
-} // namespace torchaudio
diff --git a/src/libtorchaudio/rnnt/compute.cpp b/src/libtorchaudio/rnnt/compute.cpp
index 567c9b5d4b..5aba334cee 100644
--- a/src/libtorchaudio/rnnt/compute.cpp
+++ b/src/libtorchaudio/rnnt/compute.cpp
@@ -30,4 +30,5 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
       "int blank,"
       "float clamp,"
      "bool fused_log_softmax) -> (Tensor, Tensor?)");
+  m.def("torchaudio::rnnt_loss_forward", &rnnt_loss);
 }
diff --git a/src/torchaudio/functional/functional.py b/src/torchaudio/functional/functional.py
index 42dde06814..f955fe7840 100644
--- a/src/torchaudio/functional/functional.py
+++ b/src/torchaudio/functional/functional.py
@@ -1760,6 +1760,19 @@ def _fix_waveform_shape(
     waveform_shift = waveform_shift.view(shape[:-1] + waveform_shift.shape[-1:])
     return waveform_shift

+class RnntLoss(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, *args):
+        output, saved = torch.ops.torchaudio.rnnt_loss_forward(*args)
+        ctx.save_for_backward(saved)
+        return output
+
+    @staticmethod
+    def backward(ctx, dy):
+        grad = ctx.saved_tensors[0]
+        grad_out = dy.view((-1, 1, 1, 1))
+        result = grad * grad_out
+        return (result, None, None, None, None, None, None, None)

 def _rnnt_loss(
     logits: Tensor,
@@ -1803,14 +1816,14 @@
     if blank < 0:  # reinterpret blank index if blank < 0.
         blank = logits.shape[-1] + blank

-    costs, _ = torch.ops.torchaudio.rnnt_loss(
-        logits=logits,
-        targets=targets,
-        logit_lengths=logit_lengths,
-        target_lengths=target_lengths,
-        blank=blank,
-        clamp=clamp,
-        fused_log_softmax=fused_log_softmax,
+    costs = RnntLoss.apply(
+        logits,
+        targets,
+        logit_lengths,
+        target_lengths,
+        blank,
+        clamp,
+        fused_log_softmax
     )

     if reduction == "mean":
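
Not part of the patch: a minimal sketch of how the reworked dispatch can be exercised end to end, assuming the public torchaudio.functional.rnnt_loss wrapper (which forwards to the _rnnt_loss shown above) and a build with BUILD_RNNT enabled; shapes and values are illustrative only. The forward pass now routes through RnntLoss.apply -> torch.ops.torchaudio.rnnt_loss_forward, and backward rescales the gradient tensor saved by the C++ kernel, so autograd still works from Python. Since the bridge is a Python torch.autograd.Function rather than an Autograd-key registration, the loss is no longer scriptable, which lines up with the run_test.sh change above that skips tests matching both "torchscript" and "rnnt".

import torch
import torchaudio

# Illustrative sizes: batch, time frames, target length + 1, vocabulary size.
B, T, U, D = 2, 10, 5, 7
logits = torch.randn(B, T, U, D, requires_grad=True)          # (B, max_T, max_U, D)
targets = torch.randint(1, D, (B, U - 1), dtype=torch.int32)  # labels exclude blank=0
logit_lengths = torch.full((B,), T, dtype=torch.int32)
target_lengths = torch.full((B,), U - 1, dtype=torch.int32)

loss = torchaudio.functional.rnnt_loss(
    logits, targets, logit_lengths, target_lengths, blank=0, reduction="sum"
)
loss.backward()  # gradient comes from the tensor saved in RnntLoss.forward
assert logits.grad is not None and logits.grad.shape == logits.shape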