From 58e27925138bf691c0f03ba0a8c73f4ead7840b0 Mon Sep 17 00:00:00 2001
From: Mengwei Liu
Date: Mon, 23 Jun 2025 22:07:40 -0700
Subject: [PATCH] [llm] Support different shape of input_pos

For Hugging Face models, `forward()` takes `tokens` together with
`cache_positions`, a tensor of cache indices. This differs from the .pte
files produced by `export_llama`, which take `tokens` and `input_pos`,
where `input_pos` is a scalar tensor. This PR adds support in
`text_decoder_runner.cpp` for both shapes of `input_pos`/`cache_positions`.
To keep the logic generic without relying on extra metadata, the runner now
inspects the method metadata and the input tensor info to decide whether to
feed `input_pos` or `cache_positions`.

Differential Revision: [D77203700](https://our.internmc.facebook.com/intern/diff/D77203700/)

[ghstack-poisoned]
---
 extension/llm/runner/test/CMakeLists.txt      |   2 +-
 extension/llm/runner/test/targets.bzl         |  15 ++
 .../runner/test/test_text_decoder_runner.cpp  | 199 ++++++++++++++++++
 .../llm/runner/test/test_text_llm_runner.cpp  |   7 +-
 extension/llm/runner/text_decoder_runner.cpp  |  44 +++-
 extension/llm/runner/text_decoder_runner.h    |   5 +-
 extension/llm/runner/text_llm_runner.cpp      |   3 +-
 extension/llm/runner/text_prefiller.cpp       |  13 +-
 extension/llm/runner/text_token_generator.h   |   5 +-
 runtime/executor/method_meta.h                |   2 +-
 test/models/export_program.py                 |  49 +++++
 test/models/targets.bzl                       |   3 +
 12 files changed, 317 insertions(+), 30 deletions(-)
 create mode 100644 extension/llm/runner/test/test_text_decoder_runner.cpp

diff --git a/extension/llm/runner/test/CMakeLists.txt b/extension/llm/runner/test/CMakeLists.txt
index 15b4d005f9d..78dcb25bcc5 100644
--- a/extension/llm/runner/test/CMakeLists.txt
+++ b/extension/llm/runner/test/CMakeLists.txt
@@ -18,7 +18,7 @@ set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../..)
 include(${EXECUTORCH_ROOT}/tools/cmake/Test.cmake)
 
 set(_test_srcs test_generation_config.cpp test_text_llm_runner.cpp
-               test_text_prefiller.cpp
+               test_text_prefiller.cpp test_text_decoder_runner.cpp
 )
 
 et_cxx_test(
diff --git a/extension/llm/runner/test/targets.bzl b/extension/llm/runner/test/targets.bzl
index 8bc3d4cc100..fdc2a484235 100644
--- a/extension/llm/runner/test/targets.bzl
+++ b/extension/llm/runner/test/targets.bzl
@@ -36,3 +36,18 @@ def define_common_targets():
             "//executorch/runtime/core/exec_aten/testing_util:tensor_util",
         ],
     )
+
+    runtime.cxx_test(
+        name = "test_text_decoder_runner",
+        srcs = ["test_text_decoder_runner.cpp"],
+        deps = [
+            "//executorch/extension/llm/runner:runner_lib",
+            "//executorch/kernels/portable:generated_lib",
+            "//executorch/runtime/core/exec_aten/testing_util:tensor_util",
+        ],
+        env = {
+            "KVCACHE_CACHE_POS": "$(location fbcode//executorch/test/models:exported_programs[ModuleKVCacheCachePos.pte])",
+            "KVCACHE_INPUT_POS": "$(location fbcode//executorch/test/models:exported_programs[ModuleKVCacheInputPos.pte])",
+            "NO_KVCACHE": "$(location fbcode//executorch/test/models:exported_programs[ModuleNoKVCache.pte])",
+        }
+    )
diff --git a/extension/llm/runner/test/test_text_decoder_runner.cpp b/extension/llm/runner/test/test_text_decoder_runner.cpp
new file mode 100644
index 00000000000..c9a8de271f1
--- /dev/null
+++ b/extension/llm/runner/test/test_text_decoder_runner.cpp
@@ -0,0 +1,199 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * @lint-ignore-every CLANGTIDY facebook-hte-Deprecated + */ + +#include +#include +#include +#include +#include +#include +#include + +using namespace ::testing; +using executorch::extension::Module; +using executorch::extension::TensorPtr; +using executorch::extension::llm::TextDecoderRunner; +using executorch::runtime::Error; +using executorch::runtime::EValue; +using executorch::runtime::Result; +using executorch::runtime::testing::TensorFactory; + +// Mock Module class for testing +class MockModule : public Module { + public: + MockModule() : Module("") {} +}; + +class TextDecoderRunnerTest : public Test { + protected: + void SetUp() override { + mock_module_ = std::make_unique(); + runner_ = std::make_unique(mock_module_.get()); + } + + std::unique_ptr mock_module_; + std::unique_ptr runner_; +}; + +// Test logits_to_token() method with Float tensor +TEST_F(TextDecoderRunnerTest, LogitsToTokenFloat) { + TensorFactory tf_float; + auto logits = tf_float.make({1, 4}, {0.1f, 0.2f, 0.8f, 0.4f}); + + // Call logits_to_token with temperature 0 (deterministic) + int32_t token = runner_->logits_to_token(logits, 0.0f); + + // With temperature 0, should return the argmax (index 2) + EXPECT_EQ(token, 2); +} + +// Test logits_to_token() method with 3D tensor (batch, seq_length, vocab_size) +TEST_F(TextDecoderRunnerTest, LogitsToToken3D) { + TensorFactory tf_float; + // Shape: [1, 2, 4] - batch=1, seq_length=2, vocab_size=4 + auto logits = tf_float.make( + {1, 2, 4}, + { + 0.1f, + 0.2f, + 0.3f, + 0.4f, // First sequence position + 0.5f, + 0.6f, + 0.9f, + 0.8f // Second sequence position (last) + }); + + // Call logits_to_token with temperature 0 (deterministic) + int32_t token = runner_->logits_to_token(logits, 0.0f); + + // Should use the last sequence position and return argmax (index 2) + EXPECT_EQ(token, 2); +} + +// Test logits_to_token() method with Half tensor +TEST_F(TextDecoderRunnerTest, LogitsToTokenHalf) { + TensorFactory tf_half; + auto logits = tf_half.make({1, 4}, {0.1f, 0.2f, 0.8f, 0.4f}); + + // Call logits_to_token with temperature 0 (deterministic) + int32_t token = runner_->logits_to_token(logits, 0.0f); + + // With temperature 0, should return the argmax (index 2) + EXPECT_EQ(token, 2); +} + +// Test logits_to_token() method with BFloat16 tensor +TEST_F(TextDecoderRunnerTest, LogitsToTokenBFloat16) { + TensorFactory tf_bfloat16; + auto logits = tf_bfloat16.make({1, 4}, {0.1f, 0.2f, 0.8f, 0.4f}); + + // Call logits_to_token with temperature 0 (deterministic) + int32_t token = runner_->logits_to_token(logits, 0.0f); + + // With temperature 0, should return the argmax (index 2) + EXPECT_EQ(token, 2); +} + +// Test logits_to_token() method with non-zero temperature +TEST_F(TextDecoderRunnerTest, LogitsToTokenWithTemperature) { + TensorFactory tf_float; + auto logits = tf_float.make({1, 4}, {0.1f, 0.2f, 0.8f, 0.4f}); + + // Call logits_to_token with temperature > 0 (stochastic) + int32_t token = runner_->logits_to_token(logits, 1.0f); + + // With temperature > 0, result should be within valid range + EXPECT_GE(token, 0); + EXPECT_LT(token, 4); +} + +// Test step() method with all available PTE models +TEST_F(TextDecoderRunnerTest, StepWithAllModels) { + // List of all environment variables for PTE models + std::vector> env_vars = { + {"KVCACHE_CACHE_POS", "KVCACHE_CACHE_POS"}, + {"KVCACHE_INPUT_POS", "KVCACHE_INPUT_POS"}, + {"NO_KVCACHE", 
"NO_KVCACHE"}}; + + // Check if any environment variables are set up front + bool any_env_set = false; + for (const auto& [model_name, env_var] : env_vars) { + if (std::getenv(env_var)) { + any_env_set = true; + break; + } + } + + // Skip test if no environment variables are set + if (!any_env_set) { + GTEST_SKIP() << "No PTE model environment variables were set"; + } + + bool any_model_tested = false; + + // Loop through all available models + for (const auto& [model_name, env_var] : env_vars) { + const char* model_path = std::getenv(env_var); + if (!model_path) { + continue; // Skip if environment variable not set + } + + SCOPED_TRACE( + "Testing model: " + model_name + " from " + std::string(model_path)); + + // Load the model + auto module = std::make_unique(model_path); + auto load_result = module->load(); + if (load_result != Error::Ok) { + ADD_FAILURE() << "Failed to load model " << model_name << " from " + << model_path << " with error: " << (int)load_result; + continue; + } + + // Create TextDecoderRunner + TextDecoderRunner runner(module.get()); + auto runner_load_result = runner.load(); + ASSERT_EQ(runner_load_result, Error::Ok) + << "Failed to load runner for " << model_name; + + // Verify method is loaded + EXPECT_TRUE(runner.is_method_loaded()) + << "Method not loaded for " << model_name; + + // Create input tensor pointer + + TensorFactory tf_long; + auto input_tokens_ = + tf_long.make({1, 3}, {50, 7, 11}); // Single token input + + auto input_ptr = std::make_shared(input_tokens_); + int64_t start_pos = 0; + + // Call step() and verify result is ok + auto result = runner.step(input_ptr, start_pos); + ASSERT_TRUE(result.ok()) << "step() failed for " << model_name + << " with error: " << (int)result.error(); + + // Verify output tensor is valid + auto output_tensor = result.get(); + EXPECT_GT(output_tensor.numel(), 0) + << "Output tensor empty for " << model_name; + + // Test logits_to_token works + int32_t token = runner.logits_to_token(output_tensor, 0.0f); + EXPECT_GE(token, 0) << "Invalid token for " << model_name; + + any_model_tested = true; + } + + // This should not happen since we checked environment variables up front + ASSERT_TRUE(any_model_tested) + << "No models were tested despite environment variables being set"; +} diff --git a/extension/llm/runner/test/test_text_llm_runner.cpp b/extension/llm/runner/test/test_text_llm_runner.cpp index 02f04a69b38..6896c56e961 100644 --- a/extension/llm/runner/test/test_text_llm_runner.cpp +++ b/extension/llm/runner/test/test_text_llm_runner.cpp @@ -63,11 +63,11 @@ class MockModule : public ::executorch::extension::Module { class MockTextDecoderRunner : public TextDecoderRunner { public: - MockTextDecoderRunner() : TextDecoderRunner(nullptr, false) {} + MockTextDecoderRunner() : TextDecoderRunner(nullptr) {} MOCK_METHOD( Result, step, - (executorch::extension::TensorPtr&, executorch::extension::TensorPtr&), + (executorch::extension::TensorPtr&, int64_t), ()); MOCK_METHOD(bool, is_method_loaded, (), ()); MOCK_METHOD(Result, prefill, (std::vector&, int64_t), ()); @@ -134,8 +134,7 @@ class RunnerTest : public Test { std::unique_ptr createMockTextDecoderRunner() { auto text_decoder_runner = std::make_unique(); ON_CALL(*text_decoder_runner, step) - .WillByDefault([&](executorch::extension::TensorPtr&, - executorch::extension::TensorPtr&) { + .WillByDefault([&](executorch::extension::TensorPtr&, int64_t) { return Result(tensor); }); ON_CALL(*text_decoder_runner, is_method_loaded()) diff --git 
a/extension/llm/runner/text_decoder_runner.cpp b/extension/llm/runner/text_decoder_runner.cpp
index 8705dfeb842..eac173692e7 100644
--- a/extension/llm/runner/text_decoder_runner.cpp
+++ b/extension/llm/runner/text_decoder_runner.cpp
@@ -21,18 +21,52 @@ namespace llm {
 // NOTE: we observed ~2x loading performance increase on iPhone 15
 // and a ~5% improvement on Galaxy S22 by switching to
 // FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
-TextDecoderRunner::TextDecoderRunner(Module* module, bool use_kv_cache)
-    : module_(module), use_kv_cache_(use_kv_cache) {}
+TextDecoderRunner::TextDecoderRunner(Module* module) : module_(module) {}
 
 // This function is functional, meaning it shouldn't modify any state of the
 // input. It should be safe to call multiple times with the same inputs. The
 // outer loop (call site) is responsible for managing state.
 ::executorch::runtime::Result TextDecoderRunner::step(
     TensorPtr& tokens,
-    TensorPtr& start_pos) {
+    int64_t start_pos) {
   // ET_LOG(Info, "Input token %" PRIu64, input_token);
-  if (use_kv_cache_) {
-    auto outputs_res = module_->forward({tokens, start_pos});
+  auto method_meta = ET_UNWRAP(module_->method_meta("forward"));
+  // If only 1 input, we are not using kv cache
+  bool use_kv_cache = method_meta.num_inputs() > 1;
+
+  if (use_kv_cache) {
+    // Size of the second argument. This could be either input_pos or
+    // cache_positions
+
+    // Check if we are using cache positions instead of input pos.
+    auto second_input_info = ET_UNWRAP(method_meta.input_tensor_meta(1));
+    // For input_pos, numel is 1, for cache_positions, numel is max_seq_len
+    auto sizes = second_input_info.sizes();
+    auto numel = 1;
+    std::vector<::executorch::aten::SizesType> sizes_vec;
+    for (const auto& size : sizes) {
+      sizes_vec.emplace_back(size);
+      numel *= size;
+    }
+    // Assuming the last dimension is the one with the variable token length
+    sizes_vec[sizes_vec.size() - 1] = -1;
+    TensorPtr start_pos_tensor;
+    if (numel > 1) {
+      // Assuming model is exported with cache_positions, create a tensor with
+      // the same size as cache_positions
+      start_pos_tensor = arange(
+          start_pos,
+          start_pos + tokens->numel(),
+          1,
+          sizes_vec,
+          ::executorch::aten::ScalarType::Long);
+    } else {
+      // Assuming model is exported with input_pos, create a tensor with size 1
+      start_pos_tensor =
+          from_blob(&start_pos, {1}, ::executorch::aten::ScalarType::Long);
+    }
+    ET_LOG(Info, "Start pos tensor numel: %zu", start_pos_tensor->numel());
+    auto outputs_res = module_->forward({tokens, start_pos_tensor});
     ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
     ET_CHECK_MSG(
         outputs_res.get().size() == 1,
diff --git a/extension/llm/runner/text_decoder_runner.h b/extension/llm/runner/text_decoder_runner.h
index 6c1256c6b90..0dd6ba578d3 100644
--- a/extension/llm/runner/text_decoder_runner.h
+++ b/extension/llm/runner/text_decoder_runner.h
@@ -21,7 +21,7 @@ namespace llm {
 
 class ET_EXPERIMENTAL TextDecoderRunner {
  public:
-  TextDecoderRunner(Module* module, bool use_kv_cache);
+  TextDecoderRunner(Module* module);
 
   virtual ~TextDecoderRunner() = default;
 
@@ -34,7 +34,7 @@ class ET_EXPERIMENTAL TextDecoderRunner {
    */
   virtual ::executorch::runtime::Result step(
       TensorPtr& input,
-      TensorPtr& start_pos);
+      int64_t start_pos);
 
   /**
    * Load the Module for text decode purpose.
@@ -101,7 +101,6 @@ class ET_EXPERIMENTAL TextDecoderRunner {
    * Module remains valid for the duration of TextDecoderRunner's usage.
    */
   Module* module_;
-  bool use_kv_cache_;
 
   bool should_stop_{false};
 };
diff --git a/extension/llm/runner/text_llm_runner.cpp b/extension/llm/runner/text_llm_runner.cpp
index 6a0cfd45044..b93988cffd5 100644
--- a/extension/llm/runner/text_llm_runner.cpp
+++ b/extension/llm/runner/text_llm_runner.cpp
@@ -393,8 +393,7 @@ std::unique_ptr create_text_llm_runner(
 
   // Create text_decoder_runner. Use a shared_ptr so that it can be shared with
   // TextPrefiller and TextTokenGenerator
-  auto text_decoder_runner = std::make_unique(
-      module.get(), metadata.at(kUseKVCache));
+  auto text_decoder_runner = std::make_unique(module.get());
 
   // Create text_prefiller
   auto text_prefiller = std::make_unique(
diff --git a/extension/llm/runner/text_prefiller.cpp b/extension/llm/runner/text_prefiller.cpp
index 64f3fee167b..de092b6b05d 100644
--- a/extension/llm/runner/text_prefiller.cpp
+++ b/extension/llm/runner/text_prefiller.cpp
@@ -86,10 +86,7 @@ ::executorch::runtime::Result TextPrefiller::prefill_chunk(
       {1, num_prompt_tokens},
       executorch::aten::ScalarType::Long);
 
-  auto start_pos_tensor =
-      from_blob(&start_pos, {1}, executorch::aten::ScalarType::Long);
-
-  auto outputs_res = text_decoder_runner_->step(tokens, start_pos_tensor);
+  auto outputs_res = text_decoder_runner_->step(tokens, start_pos);
 
   ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
   ET_LOG(
@@ -106,13 +103,10 @@ ::executorch::runtime::Result TextPrefiller::prefill_chunk(
     auto tokens =
         from_blob(&cur_token, {1, 1}, executorch::aten::ScalarType::Long);
 
-    auto start_pos_tensor =
-        from_blob(&start_pos, {1}, executorch::aten::ScalarType::Long);
-
     // run the first token and get back logits tensor. Assuming the first token
     // is bos so don't callback.
     auto logits_tensor =
-        ET_UNWRAP(text_decoder_runner_->step(tokens, start_pos_tensor));
+        ET_UNWRAP(text_decoder_runner_->step(tokens, start_pos));
 
     pos += 1; // start the loop from index 1
     start_pos += 1;
@@ -122,8 +116,7 @@ ::executorch::runtime::Result TextPrefiller::prefill_chunk(
       // NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
       cur_token = prompt_tokens[pos];
 
-      logits_tensor =
-          ET_UNWRAP(text_decoder_runner_->step(tokens, start_pos_tensor));
+      logits_tensor = ET_UNWRAP(text_decoder_runner_->step(tokens, start_pos));
 
       pos++;
       start_pos++;
diff --git a/extension/llm/runner/text_token_generator.h b/extension/llm/runner/text_token_generator.h
index 839ad195c7e..1a05921ed3a 100644
--- a/extension/llm/runner/text_token_generator.h
+++ b/extension/llm/runner/text_token_generator.h
@@ -78,16 +78,13 @@ class ET_EXPERIMENTAL TextTokenGenerator {
     // initialize tensor wrappers
     auto tokens_managed = from_blob(
         token_data.data(), token_shape, executorch::aten::ScalarType::Long);
-    auto start_pos_managed =
-        from_blob(&pos, {1}, executorch::aten::ScalarType::Long);
 
     should_stop_ = false;
 
     // Generate our tokens
     while (pos < start_pos + max_new_tokens) {
       // Run the model
-      auto logits_res =
-          text_decoder_runner_->step(tokens_managed, start_pos_managed);
+      auto logits_res = text_decoder_runner_->step(tokens_managed, pos);
 
       ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error());
       executorch::aten::Tensor& logits_tensor = logits_res.get();
diff --git a/runtime/executor/method_meta.h b/runtime/executor/method_meta.h
index 1b3be75ef17..eb7f93e881e 100644
--- a/runtime/executor/method_meta.h
+++ b/runtime/executor/method_meta.h
@@ -44,7 +44,7 @@ class TensorInfo final {
   /**
    * Returns the sizes of the tensor.
    */
-  Span sizes() const;
+  Span sizes() const;
 
   /**
    * Returns the dim order of the tensor.
diff --git a/test/models/export_program.py b/test/models/export_program.py
index e13b63eaf74..dac42ecee1c 100644
--- a/test/models/export_program.py
+++ b/test/models/export_program.py
@@ -213,6 +213,55 @@ def export_state_names():
     return True
 
 
+# Mimicking LLM with forward taking tokens and input_pos
+class ModuleKVCacheInputPos(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = torch.nn.Linear(3, 3)
+
+    def forward(self, x, input_pos):
+        return (self.linear(x.to(torch.float)).to(torch.long) + input_pos).to(
+            torch.float
+        )
+
+    def get_random_inputs(self):
+        return (
+            torch.randint(100, [1, 3], dtype=torch.long),
+            torch.tensor([0], dtype=torch.long),
+        )
+
+
+# Mimicking LLM with forward taking tokens and cache_positions
+class ModuleKVCacheCachePos(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = torch.nn.Linear(3, 3)
+
+    def forward(self, x, cache_positions):
+        return (self.linear(x.to(torch.float)).to(torch.long) + cache_positions).to(
+            torch.float
+        )
+
+    def get_random_inputs(self):
+        return (
+            torch.randint(100, [1, 3], dtype=torch.long),
+            torch.arange(3, dtype=torch.long),
+        )
+
+
+# Mimicking LLM with forward taking only tokens
+class ModuleNoKVCache(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = torch.nn.Linear(3, 3)
+
+    def forward(self, x):
+        return self.linear(x.to(torch.float))
+
+    def get_random_inputs(self):
+        return (torch.randint(100, [1, 3], dtype=torch.long),)
+
+
 #
 # Main logic.
 #
diff --git a/test/models/targets.bzl b/test/models/targets.bzl
index 391ce230ab8..769fcb65ccd 100644
--- a/test/models/targets.bzl
+++ b/test/models/targets.bzl
@@ -63,7 +63,10 @@ def define_common_targets():
         "ModuleAddHalf",
         "ModuleAddMul",
         "ModuleBasic",
+        "ModuleKVCacheCachePos",
+        "ModuleKVCacheInputPos",
         "ModuleMultipleEntry",
+        "ModuleNoKVCache",
         "ModuleIndex",
         "ModuleDynamicCatUnallocatedIO",
         "ModuleSimpleTrain",
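
For readers following the diff, the core of the change is the decision TextDecoderRunner::step() now makes between a scalar `input_pos` and a `cache_positions` tensor. The sketch below restates that decision in isolation; it is illustrative only and not part of the patch. The helper name `make_start_pos_tensor` is invented for this sketch, error handling is elided, and it assumes the same ExecuTorch APIs the diff itself uses (`Module::method_meta`, `MethodMeta::input_tensor_meta`, `TensorInfo::sizes`, and the `from_blob`/`arange` tensor helpers).

// Illustrative sketch only -- mirrors the dispatch added to
// TextDecoderRunner::step(); not a drop-in implementation.
#include <cstdint>
#include <vector>

#include <executorch/extension/module/module.h>
#include <executorch/extension/tensor/tensor.h>

using ::executorch::extension::arange;
using ::executorch::extension::from_blob;
using ::executorch::extension::Module;
using ::executorch::extension::TensorPtr;

// Hypothetical helper: build the second forward() argument for a decode step
// starting at `start_pos`, for a model exported with a KV cache (i.e. whose
// forward() has a second input). `start_pos` is taken by reference because
// from_blob() wraps its address without copying.
TensorPtr make_start_pos_tensor(
    Module& module,
    int64_t& start_pos,
    int64_t num_tokens) {
  // Error handling elided; real code should check the Result values.
  auto method_meta = module.method_meta("forward").get();
  auto second_input_info = method_meta.input_tensor_meta(1).get();

  // For input_pos, numel is 1; for cache_positions, numel is max_seq_len.
  std::vector<::executorch::aten::SizesType> sizes_vec;
  int64_t numel = 1;
  for (const auto size : second_input_info.sizes()) {
    sizes_vec.push_back(size);
    numel *= size;
  }

  if (numel > 1) {
    // cache_positions: the positions [start_pos, start_pos + num_tokens).
    // The last dimension carries the variable token count, so mark it dynamic.
    sizes_vec[sizes_vec.size() - 1] = -1;
    return arange(
        start_pos,
        start_pos + num_tokens,
        1,
        sizes_vec,
        ::executorch::aten::ScalarType::Long);
  }

  // input_pos: a single-element Long tensor holding the start position.
  return from_blob(&start_pos, {1}, ::executorch::aten::ScalarType::Long);
}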
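
On the caller side, `step()` now takes the position as a plain `int64_t`, so `TextPrefiller` and `TextTokenGenerator` no longer build a position tensor themselves (see the removed `from_blob(&start_pos, {1}, ...)` calls above). A minimal decode-loop sketch under that assumption follows; `decode_loop`, `first_token`, and `max_new_tokens` are placeholder names for this sketch, not APIs added by the patch, and EOS handling is omitted.

// Illustrative caller-side sketch; simplified from TextTokenGenerator.
#include <cstdint>

#include <executorch/extension/llm/runner/text_decoder_runner.h>
#include <executorch/extension/tensor/tensor.h>

using ::executorch::extension::from_blob;
using ::executorch::extension::llm::TextDecoderRunner;

void decode_loop(
    TextDecoderRunner& runner,
    int64_t first_token,
    int64_t start_pos,
    int64_t max_new_tokens) {
  // The token tensor wraps `cur_token`, so writing to `cur_token` below
  // updates the input that the next step() call sees.
  int64_t cur_token = first_token;
  auto tokens =
      from_blob(&cur_token, {1, 1}, ::executorch::aten::ScalarType::Long);

  for (int64_t pos = start_pos; pos < start_pos + max_new_tokens; ++pos) {
    // Previously the caller passed a {1}-shaped position tensor; now the
    // integer position is handed over and step() builds whatever shape the
    // exported model expects.
    auto logits = runner.step(tokens, pos);
    if (!logits.ok()) {
      return; // error handling elided in this sketch
    }
    // Greedy sampling (temperature 0), as in the unit tests above.
    cur_token = runner.logits_to_token(logits.get(), 0.0f);
  }
}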