// Patch summary (commentary only; the patch text below is unchanged).
//
// Files touched:
//   extension/llm/custom_ops/op_sdpa_aot.cpp
//   extension/llm/custom_ops/op_tile_crop_aot.cpp
//
// op_sdpa_aot.cpp: after the `namespace native {` context line, and ahead of
// the existing definition of sdpa_with_kv_cache_out_no_context, the hunk adds
// bodiless declarations for:
//   - sdpa_with_kv_cache_out_no_context / sdpa_with_kv_cache_aten
//   - custom_sdpa_out_no_context        / custom_sdpa_aten
//   - update_cache_out_no_context       / update_cache_aten
// The `*_out_no_context` declarations use `Tensor` and take a trailing
// `Tensor& output`; the `*_aten` declarations use `at::Tensor` and return by
// value.
//
// op_tile_crop_aot.cpp: a declaration of tile_crop_out_no_context and a
// declaration of tile_crop_aten are each inserted directly before the
// corresponding definition. The unchanged context shows
// tile_crop_out_no_context constructing a default
// executorch::aten::RuntimeContext and forwarding to tile_crop_out_impl, and
// tile_crop_aten starting with at::empty({4, 3, tile_size, tile_size}); the
// remainder of tile_crop_aten lies outside this hunk.
//
// Every added line is either a declaration ending in `;` or a blank line; no
// function body is modified by this patch.
//
// NOTE(review): presumably these declarations exist to satisfy a
// missing-prototype diagnostic (e.g. -Wmissing-prototypes) for functions
// defined without a prior declaration -- confirm against the commit message.
// NOTE(review): `optional` / `std::optional` appear here without template
// arguments; that looks like an extraction artifact rather than what the real
// file contains -- verify against the repository before applying.
// NOTE(review): this copy of the patch has its line breaks collapsed, so it
// is not expected to apply as-is; regenerate it from the commit if needed.
diff --git a/extension/llm/custom_ops/op_sdpa_aot.cpp b/extension/llm/custom_ops/op_sdpa_aot.cpp index 32141a9cef7..213adf1c8ab 100644 --- a/extension/llm/custom_ops/op_sdpa_aot.cpp +++ b/extension/llm/custom_ops/op_sdpa_aot.cpp @@ -17,6 +17,77 @@ namespace torch { namespace executor { namespace native { +Tensor& sdpa_with_kv_cache_out_no_context( + const Tensor& q_projected, + const Tensor& k_projected, + const Tensor& v_projected, + Tensor& key_cache, + Tensor& value_cache, + const int64_t start_pos, + const int64_t seq_len, + // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue + // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy + const optional attn_mask, + const double dropout_p, + const bool is_causal, + // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy + const optional scale, + Tensor& output); + +at::Tensor sdpa_with_kv_cache_aten( + const at::Tensor& q_projected, + const at::Tensor& k_projected, + const at::Tensor& v_projected, + at::Tensor& key_cache, + at::Tensor& value_cache, + const int64_t start_pos, + const int64_t seq_len, + // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue + // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy + const std::optional attn_mask, + const double dropout_p, + const bool is_causal, + // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy + const std::optional scale); + +Tensor& custom_sdpa_out_no_context( + const Tensor& q, + const Tensor& k, + const Tensor& v, + const int64_t start_pos, + // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue + // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy + const optional attn_mask, + const double dropout_p, + const bool is_causal, + // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy + const optional scale, + Tensor& output); + +at::Tensor custom_sdpa_aten( + const at::Tensor& q, + const at::Tensor& k, + const at::Tensor& v, + const int64_t start_pos, + // 
@lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue + // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy + const std::optional attn_mask, + const double dropout_p, + const bool is_causal, + // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy + const std::optional scale); + +Tensor& update_cache_out_no_context( + const Tensor& value, + Tensor& cache, + const int64_t start_pos, + Tensor& output); + +at::Tensor update_cache_aten( + const at::Tensor& value, + at::Tensor& cache, + const int64_t start_pos); + Tensor& sdpa_with_kv_cache_out_no_context( const Tensor& q_projected, const Tensor& k_projected, diff --git a/extension/llm/custom_ops/op_tile_crop_aot.cpp b/extension/llm/custom_ops/op_tile_crop_aot.cpp index 8a2ee6da626..5aa98ee8d4a 100644 --- a/extension/llm/custom_ops/op_tile_crop_aot.cpp +++ b/extension/llm/custom_ops/op_tile_crop_aot.cpp @@ -17,12 +17,17 @@ namespace executor { namespace native { +Tensor& +tile_crop_out_no_context(const Tensor& input, int64_t tile_size, Tensor& out); + Tensor& tile_crop_out_no_context(const Tensor& input, int64_t tile_size, Tensor& out) { executorch::aten::RuntimeContext context{}; return tile_crop_out_impl(context, input, tile_size, out); } +at::Tensor tile_crop_aten(const at::Tensor& input, int64_t tile_size); + at::Tensor tile_crop_aten(const at::Tensor& input, int64_t tile_size) { // max_num_tiles = 4, num_channels = 3. auto output = at::empty({4, 3, tile_size, tile_size});