diff --git a/kernels/portable/cpu/util/dtype_util.cpp b/kernels/portable/cpu/util/dtype_util.cpp index d240b9f83bc..525199a6f78 100644 --- a/kernels/portable/cpu/util/dtype_util.cpp +++ b/kernels/portable/cpu/util/dtype_util.cpp @@ -27,6 +27,8 @@ bool check_tensor_dtype( return executorch::runtime::tensor_is_floating_type(t); case SupportedTensorDtypes::INTB: return executorch::runtime::tensor_is_integral_type(t, true); + case SupportedTensorDtypes::BOOL: + return executorch::runtime::tensor_is_type(t, ScalarType::Bool); case SupportedTensorDtypes::BOOL_OR_BYTE: return (executorch::runtime::tensor_is_type( t, ScalarType::Bool, ScalarType::Byte)); diff --git a/kernels/portable/cpu/util/dtype_util.h b/kernels/portable/cpu/util/dtype_util.h index 1e7901c80b2..15732219c8f 100644 --- a/kernels/portable/cpu/util/dtype_util.h +++ b/kernels/portable/cpu/util/dtype_util.h @@ -72,6 +72,16 @@ load_to_compute_fn get_load_to_compute_fn_intb(const Tensor& t) { return result; } +template +load_to_compute_fn get_load_to_compute_fn_bool(const Tensor& t) { + ET_CHECK_MSG( + t.scalar_type() == ScalarType::Bool, + "Unhandled dtype %s for %s", + ::executorch::runtime::toString(t.scalar_type()), + op_name); + return internal::load_and_convert; +} + template load_to_compute_fn get_load_to_compute_fn_bool_or_byte( const Tensor& t) { @@ -165,6 +175,17 @@ store_compute_to_tensor_fn get_store_compute_to_tensor_fn_intb( return result; } +template +store_compute_to_tensor_fn get_store_compute_to_tensor_fn_bool( + const Tensor& t) { + ET_CHECK_MSG( + t.scalar_type() == ScalarType::Bool, + "Unhandled dtype %s for %s", + ::executorch::runtime::toString(t.scalar_type()), + op_name); + return internal::convert_and_store; +} + template store_compute_to_tensor_fn get_store_compute_to_tensor_fn_bool_or_byte(const Tensor& t) { @@ -219,6 +240,7 @@ enum class SupportedTensorDtypes { REALHBF16, FLOATHBF16, INTB, + BOOL, BOOL_OR_BYTE, // DEPRECATED: not likely to be correct; use SAME_AS_COMMON. SAME_AS_COMPUTE, @@ -240,6 +262,8 @@ load_to_compute_fn get_load_to_compute_fn_impl( return get_load_to_compute_fn_realhbf16(t); case SupportedTensorDtypes::INTB: return get_load_to_compute_fn_intb(t); + case SupportedTensorDtypes::BOOL: + return get_load_to_compute_fn_bool(t); case SupportedTensorDtypes::BOOL_OR_BYTE: return get_load_to_compute_fn_bool_or_byte(t); case SupportedTensorDtypes::SAME_AS_COMPUTE: @@ -271,6 +295,8 @@ store_compute_to_tensor_fn get_store_compute_to_tensor_fn( t); case SupportedTensorDtypes::INTB: return get_store_compute_to_tensor_fn_intb(t); + case SupportedTensorDtypes::BOOL: + return get_store_compute_to_tensor_fn_bool(t); case SupportedTensorDtypes::BOOL_OR_BYTE: return get_store_compute_to_tensor_fn_bool_or_byte< CTYPE_COMPUTE, @@ -318,12 +344,14 @@ bool check_tensor_dtype( const ScalarType compute_type); /// Return the one output type we are willing to emit specialized code -/// to handle, given a compute type of CTYPE_COMMON and supported +/// to handle, given a compute type of CTYPE_COMPUTE and supported /// output types of out_dtypes. template inline constexpr ScalarType specialized_output_scalar_type( SupportedTensorDtypes out_dtypes) { switch (out_dtypes) { + case SupportedTensorDtypes::BOOL: + return ScalarType::Bool; case SupportedTensorDtypes::BOOL_OR_BYTE: return ScalarType::Bool; case SupportedTensorDtypes::REALHBBF16: