Skip to content

Commit ccb3589

Browse files
committed
[llm] Add arange() tensor maker API
As titled. `arange()` taking a `sizes` argument to be able to take custom tensor sizes. Differential Revision: [D77184741](https://our.internmc.facebook.com/intern/diff/D77184741/) ghstack-source-id: 292190468 Pull Request resolved: #11861
1 parent 608a745 commit ccb3589

File tree

3 files changed

+390
-0
lines changed

3 files changed

+390
-0
lines changed

extension/tensor/tensor_ptr_maker.cpp

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,5 +186,111 @@ TensorPtr randint_strided(
186186
std::uniform_int_distribution<int64_t>(low, high - 1));
187187
}
188188

189+
TensorPtr arange(
190+
executorch::aten::Scalar start,
191+
executorch::aten::Scalar end,
192+
executorch::aten::Scalar step,
193+
std::vector<executorch::aten::SizesType> sizes,
194+
executorch::aten::ScalarType type,
195+
executorch::aten::TensorShapeDynamism dynamism) {
196+
// Calculate the number of elements in the range
197+
double start_val, end_val, step_val;
198+
199+
if (start.isFloatingPoint()) {
200+
start_val = start.to<double>();
201+
} else if (start.isIntegral(/*includeBool=*/false)) {
202+
start_val = static_cast<double>(start.to<int64_t>());
203+
} else {
204+
ET_CHECK_MSG(false, "start must be a number");
205+
}
206+
207+
if (end.isFloatingPoint()) {
208+
end_val = end.to<double>();
209+
} else if (end.isIntegral(/*includeBool=*/false)) {
210+
end_val = static_cast<double>(end.to<int64_t>());
211+
} else {
212+
ET_CHECK_MSG(false, "end must be a number");
213+
}
214+
215+
if (step.isFloatingPoint()) {
216+
step_val = step.to<double>();
217+
} else if (step.isIntegral(/*includeBool=*/false)) {
218+
step_val = static_cast<double>(step.to<int64_t>());
219+
} else {
220+
ET_CHECK_MSG(false, "step must be a number");
221+
}
222+
223+
ET_CHECK_MSG(step_val != 0, "step cannot be zero");
224+
225+
// Calculate the number of elements
226+
int64_t numel =
227+
static_cast<int64_t>(std::ceil((end_val - start_val) / step_val));
228+
numel = std::max(int64_t(0), numel); // Ensure non-negative
229+
230+
// Validate sizes compatibility with numel
231+
if (!sizes.empty()) {
232+
int64_t negative_one_count = 0;
233+
int64_t negative_one_index = -1;
234+
int64_t product = 1;
235+
236+
// Count -1s and calculate product of positive dimensions
237+
for (size_t i = 0; i < sizes.size(); ++i) {
238+
if (sizes[i] == -1) {
239+
negative_one_count++;
240+
negative_one_index = static_cast<int64_t>(i);
241+
} else if (sizes[i] <= 0) {
242+
ET_CHECK_MSG(false, "sizes must contain positive integers or -1");
243+
} else {
244+
product *= sizes[i];
245+
}
246+
}
247+
248+
// Check that there's at most one -1
249+
ET_CHECK_MSG(negative_one_count <= 1, "sizes can contain at most one -1");
250+
251+
if (negative_one_count == 1) {
252+
// Infer the -1 dimension
253+
ET_CHECK_MSG(
254+
numel % product == 0,
255+
"numel (%" PRId64
256+
") is not divisible by the product of known dimensions (%" PRId64 ")",
257+
numel,
258+
product);
259+
int64_t inferred_size = numel / product;
260+
ET_CHECK_MSG(
261+
inferred_size > 0,
262+
"inferred dimension size must be positive, got %" PRId64,
263+
inferred_size);
264+
// Update the sizes vector with the inferred dimension
265+
sizes[negative_one_index] = inferred_size;
266+
} else {
267+
// No -1, check exact match
268+
ET_CHECK_MSG(
269+
product == numel,
270+
"product of sizes (%" PRId64 ") does not match numel (%" PRId64 ")",
271+
product,
272+
numel);
273+
}
274+
}
275+
276+
// Create tensor with the provided sizes or default to 1D
277+
std::vector<executorch::aten::SizesType> tensor_sizes = sizes.empty()
278+
? std::vector<executorch::aten::SizesType>{static_cast<
279+
executorch::aten::SizesType>(numel)}
280+
: sizes;
281+
282+
auto tensor = empty(tensor_sizes, type, dynamism);
283+
284+
// Fill the tensor with values from start to end with step
285+
ET_SWITCH_REALHBBF16_TYPES(type, nullptr, "arange", CTYPE, [&] {
286+
CTYPE* data = tensor->mutable_data_ptr<CTYPE>();
287+
for (int64_t i = 0; i < numel; ++i) {
288+
data[i] = static_cast<CTYPE>(start_val + i * step_val);
289+
}
290+
});
291+
292+
return tensor;
293+
}
294+
189295
} // namespace extension
190296
} // namespace executorch

extension/tensor/tensor_ptr_maker.h

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -683,5 +683,47 @@ inline TensorPtr randint(
683683
return randint_strided(low, high, std::move(sizes), {}, type, dynamism);
684684
}
685685

686+
/**
687+
* Creates a tensor with values from `start` to `end` (exclusive) with step size
688+
* `step`. This API will error out if `sizes` is not compatible with the number
689+
* of elements of the output tensor. If `sizes` is empty, the result tensor will
690+
* be 1D.
691+
*
692+
* @param start The starting value of the sequence.
693+
* @param end The ending value of the sequence (exclusive).
694+
* @param step The step size between values in the sequence.
695+
* @param sizes A vector specifying the size of each dimension. Only 1
696+
* occurrence of -1 is allowed, the sizes need to match the number of elements
697+
* in the output tensor.
698+
* @param type The scalar type of the tensor elements.
699+
* @param dynamism Specifies whether the tensor's shape is static or dynamic.
700+
* @return A TensorPtr instance managing the newly created Tensor.
701+
*/
702+
TensorPtr arange(
703+
executorch::aten::Scalar start,
704+
executorch::aten::Scalar end,
705+
executorch::aten::Scalar step = 1,
706+
std::vector<executorch::aten::SizesType> sizes = {-1},
707+
executorch::aten::ScalarType type = executorch::aten::ScalarType::Float,
708+
executorch::aten::TensorShapeDynamism dynamism =
709+
executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND);
710+
711+
/**
712+
* Creates a 1D tensor (sizes=[max]) with values from 0 to `end` (exclusive)
713+
* with step size 1.
714+
*
715+
* @param end The ending value of the sequence (exclusive).
716+
* @param type The scalar type of the tensor elements.
717+
* @param dynamism Specifies whether the tensor's shape is static or dynamic.
718+
* @return A TensorPtr instance managing the newly created Tensor.
719+
*/
720+
inline TensorPtr arange(
721+
executorch::aten::Scalar end,
722+
executorch::aten::ScalarType type = executorch::aten::ScalarType::Float,
723+
executorch::aten::TensorShapeDynamism dynamism =
724+
executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) {
725+
return arange(0, end, 1, {-1}, type, dynamism);
726+
}
727+
686728
} // namespace extension
687729
} // namespace executorch

0 commit comments

Comments
 (0)