[llm] Support different shape of input_pos #11869

Merged · 9 commits · Jun 25, 2025
9 changes: 6 additions & 3 deletions examples/models/llava/runner/llava_text_decoder_runner.h
@@ -11,25 +11,28 @@
#pragma once

#include <executorch/extension/llm/runner/text_decoder_runner.h>
+ #include <executorch/extension/tensor/tensor.h>

namespace example {

class ET_EXPERIMENTAL LlavaTextDecoderRunner
: public executorch::extension::llm::TextDecoderRunner {
public:
explicit LlavaTextDecoderRunner(executorch::extension::Module* module)
- : TextDecoderRunner(module, true) {}
+ : TextDecoderRunner(module) {}

inline executorch::runtime::Result<executorch::aten::Tensor> step(
executorch::extension::TensorPtr& tokens,
- executorch::extension::TensorPtr& start_pos) override {
+ int64_t start_pos) override {
// run token embedding
auto token_embedding_outputs =
ET_UNWRAP(module_->execute(kTokenEmbeddingMethod, tokens));

+ auto start_pos_tensor = ::executorch::extension::from_blob(
+     &start_pos, {1}, executorch::aten::ScalarType::Long);
// run text model
auto outputs_res = ET_UNWRAP(module_->execute(
- kTextModelMethod, {start_pos, token_embedding_outputs[0]}));
+ kTextModelMethod, {start_pos_tensor, token_embedding_outputs[0]}));

ET_CHECK_MSG(
outputs_res.size() == 1,
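Note on the Llava change above: step() now receives the position as a plain int64_t, and the runner wraps it into a 1-element Long tensor itself before invoking the text model. A minimal sketch of that wrapping pattern, using only the from_blob call shown in the diff (the helper name is illustrative, not from this PR); from_blob produces a non-owning view, so the wrapped integer must outlive the returned tensor:

#include <executorch/extension/tensor/tensor.h>

#include <cstdint>

// Sketch: expose a caller-owned position counter as a {1}-shaped Long tensor
// without copying. The view does not own the memory, so `start_pos` must
// stay alive for as long as the tensor is in use.
inline executorch::extension::TensorPtr wrap_start_pos(int64_t& start_pos) {
  return ::executorch::extension::from_blob(
      &start_pos, {1}, executorch::aten::ScalarType::Long);
}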
1 change: 1 addition & 0 deletions extension/llm/runner/CMakeLists.txt
@@ -56,6 +56,7 @@ add_subdirectory(
set(runner_deps executorch_core extension_module extension_tensor tokenizers)

target_link_libraries(extension_llm_runner PUBLIC ${runner_deps})
+ set_target_properties(extension_llm_runner PROPERTIES POSITION_INDEPENDENT_CODE ON)

target_include_directories(
extension_llm_runner
1 change: 1 addition & 0 deletions extension/llm/runner/targets.bzl
@@ -34,6 +34,7 @@ def define_common_targets():
],
exported_deps = [
":stats",
"//executorch/kernels/portable/cpu/util:arange_util" + aten_suffix,
"//executorch/extension/llm/sampler:sampler" + aten_suffix,
"//executorch/extension/module:module" + aten_suffix,
"//executorch/extension/tensor:tensor" + aten_suffix,
2 changes: 1 addition & 1 deletion extension/llm/runner/test/CMakeLists.txt
@@ -18,7 +18,7 @@ set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../..)
include(${EXECUTORCH_ROOT}/tools/cmake/Test.cmake)

set(_test_srcs test_generation_config.cpp test_text_llm_runner.cpp
- test_text_prefiller.cpp
+ test_text_prefiller.cpp test_text_decoder_runner.cpp
)

et_cxx_test(
17 changes: 16 additions & 1 deletion extension/llm/runner/test/TARGETS
@@ -8,7 +8,22 @@
# targets.bzl. This file can contain fbcode-only targets.

load(":targets.bzl", "define_common_targets")

load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
oncall("executorch")

define_common_targets()

+ runtime.cxx_test(
+     name = "test_text_decoder_runner",
+     srcs = ["test_text_decoder_runner.cpp"],
+     deps = [
+         "//executorch/extension/llm/runner:runner_lib",
+         "//executorch/kernels/portable:generated_lib",
+         "//executorch/runtime/core/exec_aten/testing_util:tensor_util",
+     ],
+     env = {
+         "KVCACHE_CACHE_POS": "$(location fbcode//executorch/test/models:exported_programs[ModuleKVCacheCachePos.pte])",
+         "KVCACHE_INPUT_POS": "$(location fbcode//executorch/test/models:exported_programs[ModuleKVCacheInputPos.pte])",
+         "NO_KVCACHE": "$(location fbcode//executorch/test/models:exported_programs[ModuleNoKVCache.pte])",
+     }
+ )
199 changes: 199 additions & 0 deletions extension/llm/runner/test/test_text_decoder_runner.cpp
@@ -0,0 +1,199 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
* @lint-ignore-every CLANGTIDY facebook-hte-Deprecated
*/

#include <executorch/extension/llm/runner/text_decoder_runner.h>
#include <executorch/extension/module/module.h>
#include <executorch/extension/tensor/tensor.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <cstdlib>

using namespace ::testing;
using executorch::extension::Module;
using executorch::extension::TensorPtr;
using executorch::extension::llm::TextDecoderRunner;
using executorch::runtime::Error;
using executorch::runtime::EValue;
using executorch::runtime::Result;
using executorch::runtime::testing::TensorFactory;

// Mock Module class for testing
class MockModule : public Module {
public:
MockModule() : Module("") {}
};

class TextDecoderRunnerTest : public Test {
protected:
void SetUp() override {
mock_module_ = std::make_unique<MockModule>();
runner_ = std::make_unique<TextDecoderRunner>(mock_module_.get());
}

std::unique_ptr<MockModule> mock_module_;
std::unique_ptr<TextDecoderRunner> runner_;
};

// Test logits_to_token() method with Float tensor
TEST_F(TextDecoderRunnerTest, LogitsToTokenFloat) {
TensorFactory<executorch::aten::ScalarType::Float> tf_float;
auto logits = tf_float.make({1, 4}, {0.1f, 0.2f, 0.8f, 0.4f});

// Call logits_to_token with temperature 0 (deterministic)
int32_t token = runner_->logits_to_token(logits, 0.0f);

// With temperature 0, should return the argmax (index 2)
EXPECT_EQ(token, 2);
}

// Test logits_to_token() method with 3D tensor (batch, seq_length, vocab_size)
TEST_F(TextDecoderRunnerTest, LogitsToToken3D) {
TensorFactory<executorch::aten::ScalarType::Float> tf_float;
// Shape: [1, 2, 4] - batch=1, seq_length=2, vocab_size=4
auto logits = tf_float.make(
{1, 2, 4},
{
0.1f,
0.2f,
0.3f,
0.4f, // First sequence position
0.5f,
0.6f,
0.9f,
0.8f // Second sequence position (last)
});

// Call logits_to_token with temperature 0 (deterministic)
int32_t token = runner_->logits_to_token(logits, 0.0f);

// Should use the last sequence position and return argmax (index 2)
EXPECT_EQ(token, 2);
}

// Test logits_to_token() method with Half tensor
TEST_F(TextDecoderRunnerTest, LogitsToTokenHalf) {
TensorFactory<executorch::aten::ScalarType::Half> tf_half;
auto logits = tf_half.make({1, 4}, {0.1f, 0.2f, 0.8f, 0.4f});

// Call logits_to_token with temperature 0 (deterministic)
int32_t token = runner_->logits_to_token(logits, 0.0f);

// With temperature 0, should return the argmax (index 2)
EXPECT_EQ(token, 2);
}

// Test logits_to_token() method with BFloat16 tensor
TEST_F(TextDecoderRunnerTest, LogitsToTokenBFloat16) {
TensorFactory<executorch::aten::ScalarType::BFloat16> tf_bfloat16;
auto logits = tf_bfloat16.make({1, 4}, {0.1f, 0.2f, 0.8f, 0.4f});

// Call logits_to_token with temperature 0 (deterministic)
int32_t token = runner_->logits_to_token(logits, 0.0f);

// With temperature 0, should return the argmax (index 2)
EXPECT_EQ(token, 2);
}

// Test logits_to_token() method with non-zero temperature
TEST_F(TextDecoderRunnerTest, LogitsToTokenWithTemperature) {
TensorFactory<executorch::aten::ScalarType::Float> tf_float;
auto logits = tf_float.make({1, 4}, {0.1f, 0.2f, 0.8f, 0.4f});

// Call logits_to_token with temperature > 0 (stochastic)
int32_t token = runner_->logits_to_token(logits, 1.0f);

// With temperature > 0, result should be within valid range
EXPECT_GE(token, 0);
EXPECT_LT(token, 4);
}

// Test step() method with all available PTE models
TEST_F(TextDecoderRunnerTest, StepWithAllModels) {
// List of all environment variables for PTE models
std::vector<std::pair<std::string, const char*>> env_vars = {
{"KVCACHE_CACHE_POS", "KVCACHE_CACHE_POS"},
{"KVCACHE_INPUT_POS", "KVCACHE_INPUT_POS"},
{"NO_KVCACHE", "NO_KVCACHE"}};

// Check if any environment variables are set up front
bool any_env_set = false;
for (const auto& [model_name, env_var] : env_vars) {
if (std::getenv(env_var)) {
any_env_set = true;
break;
}
}

// Skip test if no environment variables are set
if (!any_env_set) {
GTEST_SKIP() << "No PTE model environment variables were set";
}

bool any_model_tested = false;

// Loop through all available models
for (const auto& [model_name, env_var] : env_vars) {
const char* model_path = std::getenv(env_var);
if (!model_path) {
continue; // Skip if environment variable not set
}

SCOPED_TRACE(
"Testing model: " + model_name + " from " + std::string(model_path));

// Load the model
auto module = std::make_unique<Module>(model_path);
auto load_result = module->load();
if (load_result != Error::Ok) {
ADD_FAILURE() << "Failed to load model " << model_name << " from "
<< model_path << " with error: " << (int)load_result;
continue;
}

// Create TextDecoderRunner
TextDecoderRunner runner(module.get());
auto runner_load_result = runner.load();
ASSERT_EQ(runner_load_result, Error::Ok)
<< "Failed to load runner for " << model_name;

// Verify method is loaded
EXPECT_TRUE(runner.is_method_loaded())
<< "Method not loaded for " << model_name;

// Create input tensor pointer
TensorFactory<executorch::aten::ScalarType::Long> tf_long;
auto input_tokens_ =
tf_long.make({1, 3}, {50, 7, 11}); // Three-token input

auto input_ptr = std::make_shared<executorch::aten::Tensor>(input_tokens_);
int64_t start_pos = 0;

// Call step() and verify result is ok
auto result = runner.step(input_ptr, start_pos);
ASSERT_TRUE(result.ok()) << "step() failed for " << model_name
<< " with error: " << (int)result.error();

// Verify output tensor is valid
auto output_tensor = result.get();
EXPECT_GT(output_tensor.numel(), 0)
<< "Output tensor empty for " << model_name;

// Test logits_to_token works
int32_t token = runner.logits_to_token(output_tensor, 0.0f);
EXPECT_GE(token, 0) << "Invalid token for " << model_name;

any_model_tested = true;
}

// This should not happen since we checked environment variables up front
ASSERT_TRUE(any_model_tested)
<< "No models were tested despite environment variables being set";
}
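At temperature 0, the behavior these tests pin down is an argmax over the last sequence position of the logits. An illustrative sketch of that reduction over a contiguous [1, seq_len, vocab] float buffer — not the library's implementation, just the semantics the assertions encode:

#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative: pick the argmax of the final sequence position, which is
// what logits_to_token() reduces to when temperature == 0.
int32_t argmax_last_position(
    const std::vector<float>& logits, size_t seq_len, size_t vocab_size) {
  const float* last = logits.data() + (seq_len - 1) * vocab_size;
  int32_t best = 0;
  for (size_t i = 1; i < vocab_size; ++i) {
    if (last[i] > last[best]) {
      best = static_cast<int32_t>(i);
    }
  }
  return best;
}

For the 3D test above, argmax_last_position({0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.9f, 0.8f}, 2, 4) returns 2, matching the expected token.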
7 changes: 3 additions & 4 deletions extension/llm/runner/test/test_text_llm_runner.cpp
@@ -63,11 +63,11 @@ class MockModule : public ::executorch::extension::Module {

class MockTextDecoderRunner : public TextDecoderRunner {
public:
- MockTextDecoderRunner() : TextDecoderRunner(nullptr, false) {}
+ MockTextDecoderRunner() : TextDecoderRunner(nullptr) {}
MOCK_METHOD(
Result<executorch::aten::Tensor>,
step,
- (executorch::extension::TensorPtr&, executorch::extension::TensorPtr&),
+ (executorch::extension::TensorPtr&, int64_t),
());
MOCK_METHOD(bool, is_method_loaded, (), ());
MOCK_METHOD(Result<uint64_t>, prefill, (std::vector<uint64_t>&, int64_t), ());
@@ -134,8 +134,7 @@ class RunnerTest : public Test {
std::unique_ptr<MockTextDecoderRunner> createMockTextDecoderRunner() {
auto text_decoder_runner = std::make_unique<MockTextDecoderRunner>();
ON_CALL(*text_decoder_runner, step)
- .WillByDefault([&](executorch::extension::TensorPtr&,
- executorch::extension::TensorPtr&) {
+ .WillByDefault([&](executorch::extension::TensorPtr&, int64_t) {
return Result<executorch::aten::Tensor>(tensor);
});
ON_CALL(*text_decoder_runner, is_method_loaded())
7 changes: 3 additions & 4 deletions extension/llm/runner/test/test_text_prefiller.cpp
@@ -24,11 +24,11 @@ using executorch::runtime::testing::TensorFactory;
// Mock class for TextDecoderRunner
class MockTextDecoderRunner : public TextDecoderRunner {
public:
- MockTextDecoderRunner() : TextDecoderRunner(nullptr, false) {}
+ MockTextDecoderRunner() : TextDecoderRunner(nullptr) {}
MOCK_METHOD(
Result<executorch::aten::Tensor>,
step,
- (executorch::extension::TensorPtr&, executorch::extension::TensorPtr&),
+ (executorch::extension::TensorPtr&, int64_t),
());
MOCK_METHOD(bool, is_method_loaded, (), ());
MOCK_METHOD(Result<uint64_t>, prefill, (std::vector<uint64_t>&, int64_t), ());
@@ -44,8 +44,7 @@ class TextPrefillerTest : public Test {
ON_CALL(text_decoder_runner_, is_method_loaded())
.WillByDefault(Return(true));
ON_CALL(text_decoder_runner_, step)
- .WillByDefault([&](executorch::extension::TensorPtr&,
- executorch::extension::TensorPtr&) {
+ .WillByDefault([&](executorch::extension::TensorPtr&, int64_t) {
return Result<executorch::aten::Tensor>(tensor);
});
}
46 changes: 41 additions & 5 deletions extension/llm/runner/text_decoder_runner.cpp
@@ -9,6 +9,7 @@
// Given inputs, run a text decoder and return logits.

#include <executorch/extension/llm/runner/text_decoder_runner.h>
+ #include <executorch/kernels/portable/cpu/util/arange_util.h>

#include <ctime>

@@ -21,18 +22,53 @@ namespace llm {
// NOTE: we observed ~2x loading performance increase on iPhone 15
// and a ~5% improvement on Galaxy S22 by switching to
// FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
- TextDecoderRunner::TextDecoderRunner(Module* module, bool use_kv_cache)
-     : module_(module), use_kv_cache_(use_kv_cache) {}
+ TextDecoderRunner::TextDecoderRunner(Module* module) : module_(module) {}

// This function is functional, meaning it shouldn't modify any state of the
// input. It should be safe to call multiple times with the same inputs. The
// outer loop (call site) is responsible for managing state.
::executorch::runtime::Result<executorch::aten::Tensor> TextDecoderRunner::step(
TensorPtr& tokens,
- TensorPtr& start_pos) {
+ int64_t start_pos) {
// ET_LOG(Info, "Input token %" PRIu64, input_token);
- if (use_kv_cache_) {
- auto outputs_res = module_->forward({tokens, start_pos});
+ auto method_meta = ET_UNWRAP(module_->method_meta("forward"));
+ // If only 1 input, we are not using kv cache
+ bool use_kv_cache = method_meta.num_inputs() > 1;
+
+ if (use_kv_cache) {
+ // Size of the second argument. This could be either input_pos or
+ // cache_positions
+
+ // Check if we are using cache positions instead of input pos.
+ auto second_input_info = ET_UNWRAP(method_meta.input_tensor_meta(1));
+ // For input_pos, numel is 1, for cache_positions, numel is max_seq_len
+ auto sizes = second_input_info.sizes();
+ // Assuming 1D tensor
+ ET_CHECK_OR_RETURN_ERROR(
+ sizes.size() == 1,
+ InvalidProgram,
"The second input tensor is not 1D tensor. Got dimension (%zu)",
+ sizes.size());
+ auto numel = sizes[0];
+ std::vector<::executorch::aten::SizesType> sizes_vec = {numel};
+
+ // Assuming the last dimension is the one with the variable token length,
+ // for example [1, S] or [1, 1, S]
+ sizes_vec[sizes_vec.size() - 1] = numel;
+ TensorPtr start_pos_tensor;
+ if (numel > 1) {
+ // Assuming model is exported with cache_positions, create a tensor with
+ // the same size as cache_positions
+ start_pos_tensor = empty(sizes_vec, ::executorch::aten::ScalarType::Long);
+ torch::executor::native::arange_out_impl(
+ start_pos, start_pos + numel, 1.0, *start_pos_tensor);
+ } else {
+ // Assuming model is exported with input_pos, create a tensor with size 1
+ start_pos_tensor = from_blob(
+ &start_pos, sizes_vec, ::executorch::aten::ScalarType::Long);
+ }
+ ET_LOG(Info, "Start pos tensor numel: %zu", start_pos_tensor->numel());
+ auto outputs_res = module_->forward({tokens, start_pos_tensor});
ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
ET_CHECK_MSG(
outputs_res.get().size() == 1,
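With the runner now branching on the method metadata, a call site no longer needs to know which flavor the model was exported with: if the second input to forward() has numel == 1 the scalar is passed as input_pos, otherwise a cache_positions tensor is filled via arange over [start_pos, start_pos + numel). A hedged sketch of a decode loop against the new signature (eos_id and the token re-feed step are illustrative assumptions, not part of this diff):

#include <executorch/extension/llm/runner/text_decoder_runner.h>
#include <executorch/extension/module/module.h>
#include <executorch/extension/tensor/tensor.h>

using ::executorch::extension::Module;
using ::executorch::extension::TensorPtr;
using ::executorch::extension::llm::TextDecoderRunner;

// Sketch: one decode loop that works for models exported with either
// input_pos (numel == 1) or cache_positions (numel == max_seq_len).
void decode_loop(Module& module, TensorPtr tokens, int32_t eos_id) {
  TextDecoderRunner runner(&module);
  if (runner.load() != ::executorch::runtime::Error::Ok) {
    return;
  }
  int64_t pos = 0;
  while (true) {
    auto logits = runner.step(tokens, pos); // scalar position either way
    if (!logits.ok()) {
      return;
    }
    int32_t token = runner.logits_to_token(logits.get(), /*temperature=*/0.0f);
    if (token == eos_id) {
      break;
    }
    pos += tokens->numel(); // advance by the tokens just consumed
    // ...rebuild `tokens` as a {1, 1} Long tensor holding `token` here.
  }
}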