Skip to content

Commit 180a393

Browse files
author
Sam Anklesaria
committed
Attempt to fix stable ABI calls
1 parent cc592e0 commit 180a393

File tree

3 files changed

+6
-4
lines changed

3 files changed

+6
-4
lines changed

src/libtorchaudio/rnnt/compute.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
#include <torch/csrc/stable/library.h>
2+
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
3+
24

35
STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
46
m.def(
5-
"torchaudio::rnnt_loss(Tensor logits,"
7+
"rnnt_loss(Tensor logits,"
68
"Tensor targets,"
79
"Tensor logit_lengths,"
810
"Tensor target_lengths,"
911
"int blank,"
1012
"float clamp,"
11-
"bool fused_log_softmax) -> (Tensor, Tensor?)");
13+
"bool fused_log_softmax) -> (Tensor, Tensor)");
1214
}

src/libtorchaudio/rnnt/cpu/compute.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ void boxed_compute(StableIValue* stack, uint64_t num_args, uint64_t num_outputs)
206206
}
207207

208208
STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
209-
m.impl("torchaudio::rnnt_loss", &boxed_compute);
209+
m.impl("rnnt_loss", &boxed_compute);
210210
}
211211

212212
} // namespace cpu

src/torchaudio/functional/functional.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1763,7 +1763,7 @@ def _fix_waveform_shape(
17631763
class RnntLoss(torch.autograd.Function):
17641764
@staticmethod
17651765
def forward(ctx, *args):
1766-
output, saved = torch.ops.torchaudio.rnnt_loss(*args)
1766+
output, saved = torch.ops.torchaudio.rnnt_loss.default(*args)
17671767
ctx.save_for_backward(saved)
17681768
return output
17691769

0 commit comments

Comments
 (0)