From 6a06e5a1f92b8a4d85d0e024aaeae594effd9965 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Thu, 3 Apr 2025 16:23:12 -0700 Subject: [PATCH] [Executorch][sdpa] Add accidentally removed flash attention args check as the title Differential Revision: [D71370594](https://our.internmc.facebook.com/intern/diff/D71370594/) [ghstack-poisoned] --- extension/llm/custom_ops/op_sdpa.cpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/extension/llm/custom_ops/op_sdpa.cpp b/extension/llm/custom_ops/op_sdpa.cpp index 81c02334bbc..7112489b769 100644 --- a/extension/llm/custom_ops/op_sdpa.cpp +++ b/extension/llm/custom_ops/op_sdpa.cpp @@ -294,6 +294,13 @@ Tensor& custom_sdpa_out( output, "attn_mask and is_causal cannot be set at the same time"); + ET_KERNEL_CHECK_MSG( + ctx, + validate_flash_attention_args(q, k, v, attn_mask), + InvalidArgument, + output, + "Invalid arguments"); + ET_CHECK_MSG(q.dim() == 4, "query must be a 4D tensor"); const int64_t seq_len = q.size(1);