diff --git a/kernels/portable/cpu/op_arange.cpp b/kernels/portable/cpu/op_arange.cpp
index d8859032b56..0013f1a4d0f 100644
--- a/kernels/portable/cpu/op_arange.cpp
+++ b/kernels/portable/cpu/op_arange.cpp
@@ -7,6 +7,7 @@
  */
 
 #include <cmath>
+#include <executorch/kernels/portable/cpu/util/arange_util.h>
 #include <executorch/kernels/portable/cpu/scalar_utils.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 #include <executorch/runtime/platform/assert.h>
@@ -29,9 +30,7 @@ Tensor& arange_out(KernelRuntimeContext& ctx, const Scalar& end, Tensor& out) {
 
   ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(out), InvalidArgument, out);
 
-  size_t size = static_cast<size_t>(std::ceil(end_val));
-
-  Tensor::SizesType out_length = static_cast<Tensor::SizesType>(size);
+  Tensor::SizesType out_length = compute_arange_out_size(0.0, end_val, 1.0);
 
   ET_KERNEL_CHECK(
       ctx,
@@ -39,12 +38,7 @@ Tensor& arange_out(KernelRuntimeContext& ctx, const Scalar& end, Tensor& out) {
       InvalidArgument,
       out);
 
-  ET_SWITCH_REALHBF16_TYPES(out.scalar_type(), ctx, "arange.out", CTYPE, [&]() {
-    auto out_data = out.mutable_data_ptr<CTYPE>();
-    for (size_t i = 0; i < size; i++) {
-      out_data[i] = static_cast<CTYPE>(i);
-    }
-  });
+  arange_out_impl(ctx, end_val, out);
 
   return out;
 }
@@ -77,10 +71,8 @@ Tensor& arange_start_out(
 
   ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(out), InvalidArgument, out);
 
-  double size_d = (d_end - d_start) / d_step;
-  size_t size = static_cast<size_t>(std::ceil(size_d));
-
-  Tensor::SizesType out_length = static_cast<Tensor::SizesType>(size);
+  Tensor::SizesType out_length =
+      compute_arange_out_size(d_start, d_end, d_step);
 
   ET_KERNEL_CHECK(
       ctx,
@@ -88,13 +80,7 @@ Tensor& arange_start_out(
       InvalidArgument,
       out);
 
-  ET_SWITCH_REALHBF16_TYPES(
-      out.scalar_type(), ctx, "arange.start_out", CTYPE, [&]() {
-        auto out_data = out.mutable_data_ptr<CTYPE>();
-        for (size_t i = 0; i < size; i++) {
-          out_data[i] = convert<CTYPE, double>(d_start + i * d_step);
-        }
-      });
+  arange_out_impl(ctx, d_start, d_end, d_step, out);
 
   return out;
 }
diff --git a/kernels/portable/cpu/util/arange_util.cpp b/kernels/portable/cpu/util/arange_util.cpp
new file mode 100644
index 00000000000..e13f7652736
--- /dev/null
+++ b/kernels/portable/cpu/util/arange_util.cpp
@@ -0,0 +1,51 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/kernels/portable/cpu/util/arange_util.h>
+
+namespace torch::executor::native {
+#define ET_ARANGE_IMPL(ctx, start, numel, step, out, op_name)               \
+  ET_SWITCH_REALHBF16_TYPES(out.scalar_type(), ctx, op_name, CTYPE, [&]() { \
+    auto out_data = out.mutable_data_ptr<CTYPE>();                          \
+    for (Tensor::SizesType i = 0; i < numel; ++i) {                         \
+      out_data[i] = static_cast<CTYPE>(start + i * step);                   \
+    }                                                                       \
+  })
+
+Tensor::SizesType
+compute_arange_out_size(double start, double end, double step) {
+  Tensor::SizesType numel =
+      static_cast<Tensor::SizesType>(std::ceil((end - start) / step));
+
+  ET_CHECK_MSG(
+      numel >= 0,
+      "numel should be non-negative, but got (%d). start (%f), end (%f), step (%f)",
+      numel,
+      start,
+      end,
+      step);
+  return numel;
+}
+
+void arange_out_impl(
+    KernelRuntimeContext& ctx,
+    double start,
+    double end,
+    double step,
+    Tensor& out) {
+  (void)ctx;
+  Tensor::SizesType numel = compute_arange_out_size(start, end, step);
+  ET_ARANGE_IMPL(ctx, start, numel, step, out, "arange.start_out");
+}
+
+void arange_out_impl(KernelRuntimeContext& ctx, double end, Tensor& out) {
+  (void)ctx;
+  ET_ARANGE_IMPL(ctx, 0.0, end, 1.0, out, "arange.out");
+}
+
+} // namespace torch::executor::native
diff --git a/kernels/portable/cpu/util/arange_util.h b/kernels/portable/cpu/util/arange_util.h
new file mode 100644
index 00000000000..5abb52f410c
--- /dev/null
+++ b/kernels/portable/cpu/util/arange_util.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+namespace torch::executor::native {
+
+Tensor::SizesType
+compute_arange_out_size(double start, double end, double step);
+
+inline Tensor::SizesType compute_arange_out_size(double end) {
+  return compute_arange_out_size(0.0, end, 1.0);
+}
+
+void arange_out_impl(
+    KernelRuntimeContext& ctx,
+    double start,
+    double end,
+    double step,
+    Tensor& out);
+
+void arange_out_impl(KernelRuntimeContext& ctx, double end, Tensor& out);
+
+inline void
+arange_out_impl(double start, double end, double step, Tensor& out) {
+  KernelRuntimeContext ctx;
+  arange_out_impl(ctx, start, end, step, out);
+}
+
+inline void arange_out_impl(double end, Tensor& out) {
+  KernelRuntimeContext ctx;
+  arange_out_impl(ctx, 0.0, end, 1.0, out);
+}
+} // namespace torch::executor::native
diff --git a/kernels/portable/cpu/util/targets.bzl b/kernels/portable/cpu/util/targets.bzl
index abf3f22c00b..d7ee1ac89ce 100644
--- a/kernels/portable/cpu/util/targets.bzl
+++ b/kernels/portable/cpu/util/targets.bzl
@@ -13,6 +13,7 @@ def define_common_targets():
         name = "all_deps",
         exported_deps = [
             "//executorch/extension/threadpool:threadpool",
+            "//executorch/kernels/portable/cpu/util:arange_util",
             "//executorch/kernels/portable/cpu/util:functional_util",
             "//executorch/kernels/portable/cpu/util:broadcast_util",
             "//executorch/kernels/portable/cpu/util:kernel_ops_util",
@@ -294,6 +295,19 @@ def define_common_targets():
         visibility = ["//executorch/kernels/portable/cpu/..."],
     )
 
+    runtime.cxx_library(
+        name = "arange_util",
+        srcs = ["arange_util.cpp"],
+        exported_headers = ["arange_util.h"],
+        deps = [
+            "//executorch/runtime/kernel:kernel_includes",
+        ],
+        visibility = [
+            "//executorch/kernels/portable/cpu/...",
+            "//executorch/extension/llm/...",
+        ],
+    )
+
     runtime.cxx_library(
         name = "broadcast_indexes_range",
         exported_headers = ["broadcast_indexes_range.h"],
diff --git a/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl
index 1ae20ca7c61..a2ccaec9e03 100644
--- a/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl
+++ b/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl
@@ -268,6 +268,7 @@ ATEN_OPS = (
     op_target(
         name = "op_arange",
         deps = [
+            "//executorch/kernels/portable/cpu/util:arange_util",
            "//executorch/kernels/portable/cpu/util:kernel_ops_util",
            "//executorch/runtime/core/exec_aten/util:scalar_type_util",
            "//executorch/runtime/core/exec_aten/util:tensor_util",