From 4322f8087c23743fe6afe76092dab51e30ef4968 Mon Sep 17 00:00:00 2001 From: Prashant Rawat Date: Sun, 6 Apr 2025 10:57:33 -0700 Subject: [PATCH] Support slice ops with default start (#9923) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/9923 Since D71962884, we see the following slice ops in ASR encoder: {F1976830836} This is causing failure during XNNPack delegation, since XNNPack slice pass is trying to compare start_idx 'None' to 0. This diff fixes that. Reviewed By: mcr229 Differential Revision: D72503552 --- backends/xnnpack/operators/op_slice_copy.py | 4 +++- backends/xnnpack/test/ops/test_slice_copy.py | 12 ++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/backends/xnnpack/operators/op_slice_copy.py b/backends/xnnpack/operators/op_slice_copy.py index 40d8e5f04eb..d9056afa832 100644 --- a/backends/xnnpack/operators/op_slice_copy.py +++ b/backends/xnnpack/operators/op_slice_copy.py @@ -69,7 +69,9 @@ def define_node( output_shape = [output_shape[i] for i in PERM_NCHW_TO_NHWC] dim_of_slice = PERM_NHWC_TO_NCHW[dim_of_slice] - slice_begin_index = cast(int, node.args[2]) + slice_begin_index = 0 + if len(node.args) > 2 and node.args[2]: + slice_begin_index = cast(int, node.args[2]) if slice_begin_index < 0: slice_begin_index = input_shape[dim_of_slice] + slice_begin_index diff --git a/backends/xnnpack/test/ops/test_slice_copy.py b/backends/xnnpack/test/ops/test_slice_copy.py index ea65571b1e8..857c78480ad 100644 --- a/backends/xnnpack/test/ops/test_slice_copy.py +++ b/backends/xnnpack/test/ops/test_slice_copy.py @@ -69,6 +69,18 @@ def forward(self, x): # Note that two of the slices are optimized away as they are identity. self._test_slice_copy(ConvSlice(), inputs, 4, 2) + def test_fp32_slice_copy_default_start(self): + """ + XNNPACK supports default start in slice op. + """ + + class Slice(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.slice.Tensor(x, 0, None, 2) + + inputs = (torch.randn(5, 5),) + self._test_slice_copy(Slice(), inputs, 1, 1) + def test_fp32_slice_copy_stride_non_1(self): """ XNNPACK does not support strided slicing.