Add REDUCE_MIN to reduce kernel #3113

Status: Open · wants to merge 2 commits into main
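For context, the new kernel is consumed like any other TFLM builtin. Below is a minimal usage sketch; the `AddReduceMin()` resolver hook is an assumption, since this diff adds only the kernel and its `Register_REDUCE_MIN()` entry point, not the resolver plumbing:

```cpp
// Hypothetical usage sketch: running a model that contains REDUCE_MIN.
// Assumes a MicroMutableOpResolver hook named AddReduceMin(); that hook is
// NOT part of this diff and is shown only to illustrate the wiring.
#include <cstddef>
#include <cstdint>

#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"

constexpr size_t kArenaSize = 8 * 1024;
alignas(16) static uint8_t tensor_arena[kArenaSize];

TfLiteStatus RunReduceMinModel(const tflite::Model* model) {
  tflite::MicroMutableOpResolver<1> resolver;
  resolver.AddReduceMin();  // assumed hook wrapping Register_REDUCE_MIN()

  tflite::MicroInterpreter interpreter(model, resolver, tensor_arena,
                                       kArenaSize);
  if (interpreter.AllocateTensors() != kTfLiteOk) return kTfLiteError;
  return interpreter.Invoke();
}
```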
23 changes: 18 additions & 5 deletions tensorflow/lite/micro/kernels/reduce.cc
@@ -1,4 +1,4 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -28,15 +28,17 @@ limitations under the License.

namespace tflite {

namespace {

void* InitReduce(TfLiteContext* context, const char* buffer, size_t length) {
void* op_data =
context->AllocatePersistentBuffer(context, sizeof(OpDataReduce));
return new (op_data) OpDataReduce();
}

TfLiteStatus PrepareMax(TfLiteContext* context, TfLiteNode* node) {
return PrepareMaxHelper(context, node,
static_cast<OpDataReduce*>(node->user_data));
TfLiteStatus PrepareMinMax(TfLiteContext* context, TfLiteNode* node) {
return PrepareMinMaxHelper(context, node,
static_cast<OpDataReduce*>(node->user_data));
}

TfLiteStatus PrepareMeanOrSum(TfLiteContext* context, TfLiteNode* node) {
@@ -54,17 +56,28 @@ TfLiteStatus EvalMax(TfLiteContext* context, TfLiteNode* node) {
return EvalMaxHelper(context, node, op_data);
}

TfLiteStatus EvalMin(TfLiteContext* context, TfLiteNode* node) {
OpDataReduce* op_data = static_cast<OpDataReduce*>(node->user_data);
return EvalMinHelper(context, node, op_data);
}

TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) {
return EvalSumHelper(context, node,
static_cast<OpDataReduce*>(node->user_data));
}

} // namespace

TFLMRegistration Register_MEAN() {
return tflite::micro::RegisterOp(InitReduce, PrepareMeanOrSum, EvalMean);
}

TFLMRegistration Register_REDUCE_MAX() {
return tflite::micro::RegisterOp(InitReduce, PrepareMax, EvalMax);
return tflite::micro::RegisterOp(InitReduce, PrepareMinMax, EvalMax);
}

TFLMRegistration Register_REDUCE_MIN() {
return tflite::micro::RegisterOp(InitReduce, PrepareMinMax, EvalMin);
}

TFLMRegistration Register_SUM() {
…
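A side note on `InitReduce` above: it constructs `OpDataReduce` with placement new inside arena memory handed out by `AllocatePersistentBuffer`, so the kernel never touches the heap. A standalone sketch of that pattern, with a hypothetical bump allocator standing in for the arena:

```cpp
#include <cstddef>
#include <new>

struct OpDataReduce {  // trimmed stand-in for the real struct
  int num_axis = 0;
};

// Hypothetical bump allocator standing in for
// context->AllocatePersistentBuffer(context, size).
void* FakePersistentAlloc(std::size_t size) {
  alignas(std::max_align_t) static unsigned char arena[1024];
  static std::size_t used = 0;
  void* p = arena + used;
  used += size;
  return p;
}

OpDataReduce* InitLikeInitReduce() {
  // Placement new: construct the object in pre-allocated storage.
  // Nothing is ever freed; the arena outlives the interpreter.
  void* raw = FakePersistentAlloc(sizeof(OpDataReduce));
  return new (raw) OpDataReduce();
}
```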
12 changes: 6 additions & 6 deletions tensorflow/lite/micro/kernels/reduce.h
@@ -1,4 +1,4 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -40,24 +40,24 @@ struct OpDataReduce {
int num_axis;
};

TfLiteStatus PrepareMaxHelper(TfLiteContext* context, TfLiteNode* node,
OpDataReduce* op_data);
TfLiteStatus PrepareMinMaxHelper(TfLiteContext* context, TfLiteNode* node,
OpDataReduce* op_data);

TfLiteStatus PrepareMeanOrSumHelper(TfLiteContext* context, TfLiteNode* node,
OpDataReduce* op_data);

TfLiteStatus EvalMaxHelper(TfLiteContext* context, TfLiteNode* node,
OpDataReduce* op_data);
TfLiteStatus EvalMinHelper(TfLiteContext* context, TfLiteNode* node,
OpDataReduce* op_data);
TfLiteStatus EvalMeanHelper(TfLiteContext* context, TfLiteNode* node,
OpDataReduce* op_data);
TfLiteStatus EvalSumHelper(TfLiteContext* context, TfLiteNode* node,
OpDataReduce* op_data);

void ReduceResolveAxis(const int* axis_data, int axis_count,
MeanParams* op_params);

TFLMRegistration Register_MEAN();
TFLMRegistration Register_REDUCE_MAX();
TFLMRegistration Register_REDUCE_MIN();
TFLMRegistration Register_SUM();

} // namespace tflite
…
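These helper declarations live in the header so optimized ports can reuse the reference Prepare/Eval logic instead of duplicating it. A hedged sketch of a port-side kernel delegating to the new `EvalMinHelper` (the port kernel itself is hypothetical):

```cpp
// Hypothetical port kernel that reuses the shared reduce helpers.
#include "tensorflow/lite/micro/kernels/reduce.h"

namespace tflite {
namespace {

TfLiteStatus PortEvalMin(TfLiteContext* context, TfLiteNode* node) {
  OpDataReduce* op_data = static_cast<OpDataReduce*>(node->user_data);
  // A real port would try a vendor-optimized path first and fall back
  // to the reference helper for unsupported types or shapes.
  return EvalMinHelper(context, node, op_data);
}

}  // namespace
}  // namespace tflite
```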
241 changes: 140 additions & 101 deletions tensorflow/lite/micro/kernels/reduce_common.cc
@@ -1,4 +1,4 @@
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -31,6 +31,8 @@ namespace tflite {
const int kMaxNumberOfAxis = 5;
const int kMaxNumberOfReducedAxis = 2;

namespace {

TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node,
int32_t* multiplier, int* shift) {
MicroContext* micro_context = GetMicroContext(context);
@@ -64,8 +66,138 @@ TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node,
return kTfLiteOk;
}

TfLiteStatus PrepareMaxHelper(TfLiteContext* context, TfLiteNode* node,
OpDataReduce* op_data) {
void ResolveAxis(const int* axis_data, int axis_count,
tflite::MeanParams* op_params) {
int i = 0;
for (; i < axis_count; ++i) {
op_params->axis[i] = static_cast<int16_t>(axis_data[i]);
}
for (; i < 4; ++i) {
op_params->axis[i] = 1;
}
op_params->axis_count = axis_count;
}

template <typename T>
TfLiteStatus QuantizedMeanOrSum(TfLiteContext* context, TfLiteNode* node,
int* temp_index, int* resolved_axis,
int32_t* temp_sum, OpDataReduce* op_data,
bool compute_sum) {
const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
const TfLiteEvalTensor* axis = tflite::micro::GetEvalInput(context, node, 1);
TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
TfLiteReducerParams* params =
static_cast<TfLiteReducerParams*>(node->builtin_data);

bool result = reference_ops::QuantizedMeanOrSumExtraArgs<T, int32_t>(
tflite::micro::GetTensorData<T>(input), op_data->input_zp,
op_data->input_scale, &input->dims->data[0], input->dims->size,
tflite::micro::GetTensorData<T>(output), op_data->output_scale,
op_data->multiplier, op_data->shift, op_data->output_zp,
&output->dims->data[0], output->dims->size,
tflite::micro::GetTensorData<int>(axis), op_data->num_axis,
params->keep_dims, temp_index, resolved_axis, temp_sum, compute_sum);
TF_LITE_ENSURE(context, result);

return kTfLiteOk;
}

template <typename integer_type>
TfLiteStatus EvalIntegerMean(TfLiteContext* context, TfLiteNode* node,
int num_axis, OpDataReduce* op_data,
int* temp_index, int* resolved_axis) {
int32_t* temp_sum = static_cast<int32_t*>(
context->GetScratchBuffer(context, op_data->temp_buffer_idx));

QuantizedMeanOrSum<integer_type>(context, node, temp_index, resolved_axis,
temp_sum, op_data, /*compute_sum=*/false);

return kTfLiteOk;
}

enum MinMaxEvalType { kEvalMin, kEvalMax };

template <typename T>
struct MinMaxReducerParams {
MinMaxReducerParams() = delete;
MinMaxReducerParams(MinMaxEvalType evalType) : type_(evalType){};

constexpr T initialValue() const {
return (type_ == kEvalMin) ? std::numeric_limits<T>::max()
: std::numeric_limits<T>::lowest();
}

// should be able to use "auto" keyword here, but GCC and Clang blow a fuse
T (*compare())(const T, const T) {
if (type_ == kEvalMin) {
return [](const T current, const T in) -> T {
return (in < current) ? in : current;
};
} else {
return [](const T current, const T in) -> T {
return (in > current) ? in : current;
};
}
}

private:
MinMaxEvalType type_;
};

TfLiteStatus EvalMinMaxHelper(TfLiteContext* context, TfLiteNode* node,
OpDataReduce* op_data, MinMaxEvalType evalType) {
const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
const TfLiteEvalTensor* axis = tflite::micro::GetEvalInput(context, node, 1);
TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
TfLiteReducerParams* params =
static_cast<TfLiteReducerParams*>(node->builtin_data);

// Interpret an axis tensor with null dimensions as a scalar
int num_axis = static_cast<int>(ElementCount(*axis->dims));
int* temp_buffer = static_cast<int*>(
context->GetScratchBuffer(context, op_data->temp_buffer_idx));
int* resolved_axis = static_cast<int*>(
context->GetScratchBuffer(context, op_data->resolved_axis_idx));
switch (input->type) {
case kTfLiteFloat32: {
MinMaxReducerParams<float> reducer(evalType);
TF_LITE_ENSURE(
context,
reference_ops::ReduceGeneric<float>(
tflite::micro::GetTensorData<float>(input), input->dims->data,
input->dims->size, tflite::micro::GetTensorData<float>(output),
output->dims->data, output->dims->size,
tflite::micro::GetTensorData<int>(axis), num_axis,
params->keep_dims, temp_buffer, resolved_axis,
reducer.initialValue(), reducer.compare()));
} break;
case kTfLiteInt8: {
MinMaxReducerParams<int8_t> reducer(evalType);
TF_LITE_ENSURE_EQ(context, static_cast<double>(op_data->input_scale),
static_cast<double>(op_data->output_scale));
TF_LITE_ENSURE_EQ(context, op_data->input_zp, op_data->output_zp);
TF_LITE_ENSURE(
context,
reference_ops::ReduceGeneric<int8_t>(
tflite::micro::GetTensorData<int8_t>(input), input->dims->data,
input->dims->size, tflite::micro::GetTensorData<int8_t>(output),
output->dims->data, output->dims->size,
tflite::micro::GetTensorData<int>(axis), num_axis,
params->keep_dims, temp_buffer, resolved_axis,
reducer.initialValue(), reducer.compare()));
} break;
default:
MicroPrintf("Only float32 and int8 types are supported.");
return kTfLiteError;
}
return kTfLiteOk;
}

} // namespace

TfLiteStatus PrepareMinMaxHelper(TfLiteContext* context, TfLiteNode* node,
OpDataReduce* op_data) {
TF_LITE_ENSURE_OK(context, PrepareSimple(context, node, &op_data->multiplier,
&op_data->shift));

@@ -126,55 +258,6 @@ TfLiteStatus PrepareMeanOrSumHelper(TfLiteContext* context, TfLiteNode* node,
return kTfLiteOk;
}

void ResolveAxis(const int* axis_data, int axis_count,
tflite::MeanParams* op_params) {
int i = 0;
for (; i < axis_count; ++i) {
op_params->axis[i] = static_cast<int16_t>(axis_data[i]);
}
for (; i < 4; ++i) {
op_params->axis[i] = 1;
}
op_params->axis_count = axis_count;
}

template <typename T>
TfLiteStatus QuantizedMeanOrSum(TfLiteContext* context, TfLiteNode* node,
int* temp_index, int* resolved_axis,
int32_t* temp_sum, OpDataReduce* op_data,
bool compute_sum) {
const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
const TfLiteEvalTensor* axis = tflite::micro::GetEvalInput(context, node, 1);
TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
TfLiteReducerParams* params =
static_cast<TfLiteReducerParams*>(node->builtin_data);

bool result = reference_ops::QuantizedMeanOrSumExtraArgs<T, int32_t>(
tflite::micro::GetTensorData<T>(input), op_data->input_zp,
op_data->input_scale, &input->dims->data[0], input->dims->size,
tflite::micro::GetTensorData<T>(output), op_data->output_scale,
op_data->multiplier, op_data->shift, op_data->output_zp,
&output->dims->data[0], output->dims->size,
tflite::micro::GetTensorData<int>(axis), op_data->num_axis,
params->keep_dims, temp_index, resolved_axis, temp_sum, compute_sum);
TF_LITE_ENSURE(context, result);

return kTfLiteOk;
}

template <typename integer_type>
TfLiteStatus EvalIntegerMean(TfLiteContext* context, TfLiteNode* node,
int num_axis, OpDataReduce* op_data,
int* temp_index, int* resolved_axis) {
int32_t* temp_sum = static_cast<int32_t*>(
context->GetScratchBuffer(context, op_data->temp_buffer_idx));

QuantizedMeanOrSum<integer_type>(context, node, temp_index, resolved_axis,
temp_sum, op_data, /*compute_sum=*/false);

return kTfLiteOk;
}

TfLiteStatus EvalMeanHelper(TfLiteContext* context, TfLiteNode* node,
OpDataReduce* op_data) {
const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
@@ -238,56 +321,12 @@ TfLiteStatus EvalMeanHelper(TfLiteContext* context, TfLiteNode* node,

TfLiteStatus EvalMaxHelper(TfLiteContext* context, TfLiteNode* node,
OpDataReduce* op_data) {
const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
const TfLiteEvalTensor* axis = tflite::micro::GetEvalInput(context, node, 1);
TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
TfLiteReducerParams* params =
static_cast<TfLiteReducerParams*>(node->builtin_data);
return EvalMinMaxHelper(context, node, op_data, kEvalMax);
}

// Interpret an axis tensor with null dimensions as a scalar
int num_axis = static_cast<int>(ElementCount(*axis->dims));
int* temp_buffer = static_cast<int*>(
context->GetScratchBuffer(context, op_data->temp_buffer_idx));
int* resolved_axis = static_cast<int*>(
context->GetScratchBuffer(context, op_data->resolved_axis_idx));
switch (input->type) {
case kTfLiteFloat32:
TF_LITE_ENSURE(
context,
reference_ops::ReduceGeneric<float>(
tflite::micro::GetTensorData<float>(input), input->dims->data,
input->dims->size, tflite::micro::GetTensorData<float>(output),
output->dims->data, output->dims->size,
tflite::micro::GetTensorData<int>(axis), num_axis,
params->keep_dims, temp_buffer, resolved_axis,
std::numeric_limits<float>::lowest(),
[](const float current, const float in) -> float {
return (in > current) ? in : current;
}));
break;
case kTfLiteInt8:
TF_LITE_ENSURE_EQ(context, static_cast<double>(op_data->input_scale),
static_cast<double>(op_data->output_scale));
TF_LITE_ENSURE_EQ(context, op_data->input_zp, op_data->output_zp);
TF_LITE_ENSURE(
context,
reference_ops::ReduceGeneric<int8_t>(
tflite::micro::GetTensorData<int8_t>(input), input->dims->data,
input->dims->size, tflite::micro::GetTensorData<int8_t>(output),
output->dims->data, output->dims->size,
tflite::micro::GetTensorData<int>(axis), num_axis,
params->keep_dims, temp_buffer, resolved_axis,
std::numeric_limits<int8_t>::lowest(),
[](const int8_t current, const int8_t in) -> int8_t {
return (in > current) ? in : current;
}));
break;
default:
MicroPrintf("Only float32 and int8 types are supported.");
return kTfLiteError;
}
return kTfLiteOk;
TfLiteStatus EvalMinHelper(TfLiteContext* context, TfLiteNode* node,
OpDataReduce* op_data) {
return EvalMinMaxHelper(context, node, op_data, kEvalMin);
}

TfLiteStatus EvalSumHelper(TfLiteContext* context, TfLiteNode* node,
…
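Worth spelling out why `compare()` in `MinMaxReducerParams` can return a plain function pointer: the lambdas capture nothing, and a captureless lambda implicitly converts to `T (*)(const T, const T)`, which is what the diff's "auto" comment alludes to. A self-contained sketch of the same pattern, plus a manual fold showing the semantics `ReduceGeneric` applies per element:

```cpp
#include <cstdio>
#include <limits>

// Standalone copy of the selection pattern from reduce_common.cc: pick the
// fold's identity element and a comparator at runtime from the eval type.
enum MinMaxEvalType { kEvalMin, kEvalMax };

template <typename T>
struct MinMaxReducerParams {
  explicit MinMaxReducerParams(MinMaxEvalType t) : type_(t) {}

  constexpr T initialValue() const {
    return (type_ == kEvalMin) ? std::numeric_limits<T>::max()
                               : std::numeric_limits<T>::lowest();
  }

  // A captureless lambda converts to a function pointer, so no
  // std::function (and no heap allocation) is needed.
  T (*compare())(const T, const T) {
    if (type_ == kEvalMin) {
      return [](const T cur, const T in) -> T { return in < cur ? in : cur; };
    }
    return [](const T cur, const T in) -> T { return in > cur ? in : cur; };
  }

 private:
  MinMaxEvalType type_;
};

int main() {
  const float data[] = {3.0f, -1.0f, 7.0f};
  MinMaxReducerParams<float> reducer(kEvalMin);
  float acc = reducer.initialValue();
  for (float v : data) acc = reducer.compare()(acc, v);
  std::printf("min = %.1f\n", acc);  // prints: min = -1.0
  return 0;
}
```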