@@ -83,7 +83,7 @@ Tensor& add_out(
     Tensor& out) {
   ET_KERNEL_CHECK(
       ctx,
-      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
+      torch::executor::resize_to_broadcast_target_size(a, b, out) == Error::Ok,
       InvalidArgument,
       out);
@@ -93,25 +93,36 @@ Tensor& add_out(
       InvalidArgument,
       out);
   ET_KERNEL_CHECK(
-      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
+      ctx,
+      executorch::runtime::tensors_have_same_dim_order(a, b, out),
+      InvalidArgument,
+      out);
 
   ScalarType a_type = a.scalar_type();
   ScalarType b_type = b.scalar_type();
-  ScalarType alpha_type =
-      torch::executor::native::utils::get_scalar_dtype(alpha);
-  ScalarType common_type = promoteTypes(a_type, b_type, /*half_to_float*/ true);
+  ScalarType alpha_type =
+      torch::executor::native::utils::get_scalar_dtype(alpha);
+  ScalarType common_type =
+      executorch::runtime::promoteTypes(a_type, b_type, /*half_to_float*/ true);
   ScalarType out_type = out.scalar_type();
 
-  ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
   ET_KERNEL_CHECK(
-      ctx, check_alpha_type(alpha_type, common_type), InvalidArgument, out);
-
+      ctx,
+      executorch::runtime::canCast(common_type, out_type),
+      InvalidArgument,
+      out);
+  ET_KERNEL_CHECK(
+      ctx,
+      torch::executor::check_alpha_type(alpha_type, common_type),
+      InvalidArgument,
+      out);
+
   float alpha_val;
   torch::executor::native::utils::extract_scalar(alpha, &alpha_val);
 
   constexpr auto name = "add.out";
   constexpr int kNnlibMaxDim = 4; /* fallback if broadcast and dim > 4 */
-
+
   int a_dim = a.dim(), b_dim = b.dim(), out_dim = out.dim();
   bool optimized = 1;
   /* find broadcast*/
@@ -124,51 +135,48 @@ Tensor& add_out(
   if ((out_type != ScalarType::Float) || (alpha_val != 1.0))
     optimized = 0;
 
-  if ((a_dim == 0) || (b_dim == 0) )
+  if ((a_dim == 0) || (b_dim == 0))
     optimized = 0;
 
   if ((broadcast == 1) && (max_dim > kNnlibMaxDim))
     optimized = 0;
 
-
   if (optimized) {
-    const float* const a_data = a.const_data_ptr<float>();
-    const float* const b_data = b.const_data_ptr<float>();
-    float* const out_data = out.mutable_data_ptr<float>();
-
-    if (broadcast == 1) {
-      int out_shape[kNnlibMaxDim];
-      int inp1_shape[kNnlibMaxDim];
-      int inp2_shape[kNnlibMaxDim];
-
-      for (int i = 0; i < kNnlibMaxDim; i++) {
-        out_shape[i] = 1;
-        inp1_shape[i] = 1;
-        inp2_shape[i] = 1;
-      }
-
-      int off_o = kNnlibMaxDim - out.dim();
-      int off_a = kNnlibMaxDim - a.dim();
-      int off_b = kNnlibMaxDim - b.dim();
-
-      for (int i = 0; i < out.dim(); i++)
-        out_shape[i+off_o] = out.size(i);
-      for (int i = 0; i < a.dim(); i++)
-        inp1_shape[i+off_a] = a.size(i);
-      for (int i = 0; i < b.dim(); i++)
-        inp2_shape[i+off_b] = b.size(i);
-
-      xa_nn_elm_add_broadcast_4D_f32xf32_f32(
-          out_data, out_shape, a_data, inp1_shape, b_data, inp2_shape);
-    }
-    else
-    {
-      xa_nn_elm_add_f32xf32_f32(out_data, a_data, b_data, out.numel());
+    const float* const a_data = a.const_data_ptr<float>();
+    const float* const b_data = b.const_data_ptr<float>();
+    float* const out_data = out.mutable_data_ptr<float>();
+
+    if (broadcast == 1) {
+      int out_shape[kNnlibMaxDim];
+      int inp1_shape[kNnlibMaxDim];
+      int inp2_shape[kNnlibMaxDim];
+
+      for (int i = 0; i < kNnlibMaxDim; i++) {
+        out_shape[i] = 1;
+        inp1_shape[i] = 1;
+        inp2_shape[i] = 1;
       }
 
-    return out;
+      int off_o = kNnlibMaxDim - out.dim();
+      int off_a = kNnlibMaxDim - a.dim();
+      int off_b = kNnlibMaxDim - b.dim();
+
+      for (int i = 0; i < out.dim(); i++)
+        out_shape[i + off_o] = out.size(i);
+      for (int i = 0; i < a.dim(); i++)
+        inp1_shape[i + off_a] = a.size(i);
+      for (int i = 0; i < b.dim(); i++)
+        inp2_shape[i + off_b] = b.size(i);
+
+      xa_nn_elm_add_broadcast_4D_f32xf32_f32(
+          out_data, out_shape, a_data, inp1_shape, b_data, inp2_shape);
+    } else {
+      xa_nn_elm_add_f32xf32_f32(out_data, a_data, b_data, out.numel());
+    }
+
+    return out;
   }
-
+
   ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, name, CTYPE_A, [&]() {
     ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, name, CTYPE_B, [&]() {
       using CTYPE_IN = typename torch::executor::
@@ -191,7 +199,6 @@ Tensor& add_out(
   return out;
 }
 
-
 } // namespace native
 } // namespace HiFi
 } // namespace impl
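
Note on the optimized NNLib path in the hunk above: before calling xa_nn_elm_add_broadcast_4D_f32xf32_f32, the kernel right-aligns each tensor's sizes into a kNnlibMaxDim (= 4) element shape array and fills the leading entries with 1. The following is a minimal standalone sketch of that padding step only, using a hypothetical rank-2 shape {3, 5} that is not taken from the diff.

// Illustrative sketch only (not part of the change): pad a tensor shape to
// kNnlibMaxDim the same way the optimized path does. Sizes are right-aligned
// and leading dims default to 1, so {3, 5} becomes {1, 1, 3, 5}.
#include <iostream>

int main() {
  constexpr int kNnlibMaxDim = 4;
  const int sizes[] = {3, 5};  // hypothetical input shape, rank 2
  const int rank = 2;

  int padded[kNnlibMaxDim];
  for (int i = 0; i < kNnlibMaxDim; i++)
    padded[i] = 1;  // every axis defaults to size 1

  const int off = kNnlibMaxDim - rank;  // same offset trick as off_o/off_a/off_b
  for (int i = 0; i < rank; i++)
    padded[i + off] = sizes[i];  // right-align the real sizes

  for (int i = 0; i < kNnlibMaxDim; i++)
    std::cout << padded[i] << ' ';  // prints: 1 1 3 5
  std::cout << '\n';
  return 0;
}

The same offset arithmetic (off_o, off_a, off_b) is what lets any operands of rank <= 4 share the single fixed-rank 4-D broadcast kernel; higher ranks fall through to the ET_SWITCH_REALHBBF16_TYPES portable path.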