diff --git a/kernels/quantized/cpu/op_embedding.cpp b/kernels/quantized/cpu/op_embedding.cpp index b297d91870a..c43755ed3da 100644 --- a/kernels/quantized/cpu/op_embedding.cpp +++ b/kernels/quantized/cpu/op_embedding.cpp @@ -153,6 +153,22 @@ void embedding_byte_per_channel( for (int i = 0; i < indices.numel(); i++) { int64_t index = indices_ptr[i]; + + // Check if index is out of bounds for both weight and weight_scales + ET_CHECK_MSG( + index >= 0 && index < weight.size(0), + "Index out of bounds for weight: index %" PRId64 + " must be in range [0, %zd)", + index, + weight.size(0)); + + ET_CHECK_MSG( + index >= 0 && index < weight_scales.size(0), + "Index out of bounds for weight_scales: index %" PRId64 + " must be in range [0, %zd)", + index, + weight_scales.size(0)); + // If using groupwise embedding int32_t qparams_index = index * num_groups_per_channel; CTYPE_PARAMS zp = 0.0; diff --git a/kernels/quantized/test/op_embedding_test.cpp b/kernels/quantized/test/op_embedding_test.cpp index 6c949bd6e69..68359f5e45b 100644 --- a/kernels/quantized/test/op_embedding_test.cpp +++ b/kernels/quantized/test/op_embedding_test.cpp @@ -373,3 +373,38 @@ TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath5) { out), ""); } + +TEST(OpQuantizedEmbeddingTest, TestOutOfBoundsIndex) { + et_pal_init(); + TensorFactory tf; + TensorFactory tf_l; + + int64_t quant_min = 0; + int64_t quant_max = 255; + + // Create a weight tensor with 3 rows + TensorFactory tfo; + Tensor qweight = + tfo.make({3, 4}, {8, 10, 12, 14, 10, 12, 12, 14, 8, 9, 10, 12}); + + // Create weight_scales with the same number of rows + Tensor weight_scales = tf.make({3, 1}, {0.5, 1.0, 1.5}); + Tensor weight_zero_points = tf.make({3, 1}, {1, 5, 7}); + + // Create indices with an out-of-bounds index (3, which is >= weight.size(0)) + Tensor indices = tf_l.make({2}, {1, 3}); + + Tensor out = tf.zeros({2, 4}); + + // Expect death when accessing an out-of-bounds index + ET_EXPECT_DEATH( + quantized_embedding_byte_out( + qweight, + weight_scales, + weight_zero_points, + quant_min, + quant_max, + indices, + out), + "Index out of bounds for weight"); +}