diff --git a/extension/llm/custom_ops/op_sdpa.cpp b/extension/llm/custom_ops/op_sdpa.cpp index 81c02334bbc..7112489b769 100644 --- a/extension/llm/custom_ops/op_sdpa.cpp +++ b/extension/llm/custom_ops/op_sdpa.cpp @@ -294,6 +294,13 @@ Tensor& custom_sdpa_out( output, "attn_mask and is_causal cannot be set at the same time"); + ET_KERNEL_CHECK_MSG( + ctx, + validate_flash_attention_args(q, k, v, attn_mask), + InvalidArgument, + output, + "Invalid arguments"); + ET_CHECK_MSG(q.dim() == 4, "query must be a 4D tensor"); const int64_t seq_len = q.size(1);