diff --git a/kernels/portable/cpu/op_clamp.cpp b/kernels/portable/cpu/op_clamp.cpp index 1cc3c4af4b0..8a824444814 100644 --- a/kernels/portable/cpu/op_clamp.cpp +++ b/kernels/portable/cpu/op_clamp.cpp @@ -84,99 +84,38 @@ Tensor& clamp_out( Error err = resize_tensor(out, in.sizes()); ET_CHECK_MSG(err == Error::Ok, "Could not resize output"); - ScalarType in_type = in.scalar_type(); - ScalarType min_type = in_type; - ScalarType max_type = in_type; - ScalarType common_type = in_type; - ScalarType out_type = out.scalar_type(); - - bool has_min = min_opt.has_value(); - if (has_min) { - min_type = utils::get_scalar_dtype(min_opt.value()); - common_type = utils::promote_type_with_scalar(common_type, min_opt.value()); - } - bool has_max = max_opt.has_value(); - if (has_max) { - max_type = utils::get_scalar_dtype(max_opt.value()); - common_type = utils::promote_type_with_scalar(common_type, max_opt.value()); - } - - ET_CHECK_MSG( - has_min || has_max, "At least one of 'min' or 'max' must not be None"); + ET_CHECK_SAME_SHAPE_AND_DTYPE2(in, out); - ET_CHECK(common_type == out_type); - - ET_SWITCH_REAL_TYPES(out_type, ctx, "clamp", CTYPE_OUT, [&]() { + ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, "clamp", CTYPE, [&]() { // Extract optional min value - CTYPE_OUT min = 0; + CTYPE min = 0; + bool has_min = min_opt.has_value(); if (has_min) { - ET_SWITCH_SCALAR_OBJ_TYPES(min_type, ctx, "clamp", CTYPE_MIN, [&]() { - CTYPE_MIN min_val = 0; - ET_EXTRACT_SCALAR(min_opt.value(), min_val); - if (isIntegralType(out_type, /*includeBool=*/false)) { - if (static_cast(min_val) < - std::numeric_limits::lowest() || - static_cast(min_val) > - std::numeric_limits::max()) { - ET_CHECK_MSG(false, "minimum value out of bounds"); - } - } - if (isFloatingType(out_type)) { - if (std::isfinite(min_val) && - (static_cast(min_val) < - std::numeric_limits::lowest() || - static_cast(min_val) > - std::numeric_limits::max())) { - ET_CHECK_MSG(false, "minimum value out of bounds"); - } - } - min = static_cast(min_val); - }); + bool ok = utils::extract_scalar(min_opt.value(), &min); + ET_CHECK_MSG(ok, "Invalid min value: wrong type or out of range"); } - // Extract optional max value - CTYPE_OUT max = 0; + CTYPE max = 0; + bool has_max = max_opt.has_value(); if (has_max) { - ET_SWITCH_SCALAR_OBJ_TYPES(max_type, ctx, "clamp", CTYPE_MAX, [&]() { - CTYPE_MAX max_val = 0; - ET_EXTRACT_SCALAR(max_opt.value(), max_val); - if (isIntegralType(out_type, /*includeBool=*/false)) { - if (static_cast(max_val) < - std::numeric_limits::lowest() || - static_cast(max_val) > - std::numeric_limits::max()) { - ET_CHECK_MSG(false, "maximum value out of bounds"); - } - } - if (isFloatingType(out_type)) { - if (std::isfinite(max_val) && - (static_cast(max_val) < - std::numeric_limits::lowest() || - static_cast(max_val) > - std::numeric_limits::max())) { - ET_CHECK_MSG(false, "maximum value out of bounds"); - } - } - max = static_cast(max_val); - }); + bool ok = utils::extract_scalar(max_opt.value(), &max); + ET_CHECK_MSG(ok, "Invalid max value: wrong type or out of range"); } - ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, "clamp", CTYPE_IN, [&]() { - apply_unary_map_fn( - [has_min, min, has_max, max](const CTYPE_IN val_in) { - CTYPE_OUT val_out = static_cast(val_in); - if (has_min) { - val_out = max_override(val_out, min); - } - if (has_max) { - val_out = min_override(val_out, max); - } - return val_out; - }, - in.const_data_ptr(), - out.mutable_data_ptr(), - in.numel()); - }); + apply_unary_map_fn( + [has_min, min, has_max, max](const CTYPE val_in) { + CTYPE val_out = val_in; + if (has_min) { + val_out = max_override(val_out, min); + } + if (has_max) { + val_out = min_override(val_out, max); + } + return val_out; + }, + in.const_data_ptr(), + out.mutable_data_ptr(), + in.numel()); }); return out; diff --git a/kernels/test/op_clamp_test.cpp b/kernels/test/op_clamp_test.cpp index 08d898733e1..b505f91bc85 100644 --- a/kernels/test/op_clamp_test.cpp +++ b/kernels/test/op_clamp_test.cpp @@ -303,12 +303,12 @@ TEST(OpClampOutTest, ByteTensorFloatingPointClampDies) { #ifndef USE_ATEN_LIB TEST(OpClampOutTest, IntTensorTooSmallClampDies) { - // Cannot be represented by a int32_t. + // Cannot be represented by a uint32_t. expect_bad_clamp_value_dies(-2147483649); } TEST(OpClampOutTest, IntTensorTooLargeClampDies) { - // Cannot be represented by a int32_t. + // Cannot be represented by a uint32_t. expect_bad_clamp_value_dies(2147483648); } #endif