
Commit 6312ea6

Update on "[ExecuTorch][Llama] Change runner to enable chunked prefill"
This diff adds code to chunk prompts longer than max_seq_len so that longer contexts can be prefilled.

Differential Revision: [D71833061](https://our.internmc.facebook.com/intern/diff/D71833061/)

[ghstack-poisoned]
1 parent ae93078 commit 6312ea6
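
As a concrete illustration of what the change enables (the numbers are illustrative, not taken from the diff): with max_seq_len = 128 the runner keeps max_seq_len_ = 127, so a 300-token prompt is prefilled as three chunks of 127, 127, and 46 tokens, each submitted at the current KV-cache position, rather than the whole prompt having to fit into a single prefill call.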

2 files changed: +17 −13 lines changed

extension/llm/runner/text_prefiller.cpp

Lines changed: 15 additions & 12 deletions
@@ -24,7 +24,8 @@ TextPrefiller::TextPrefiller(
     : text_decoder_runner_(text_decoder_runner),
       use_kv_cache_(use_kv_cache),
       enable_parallel_prefill_(enable_parallel_prefill),
-      max_seq_len_(max_seq_len > 0 ? max_seq_len - 1 : 127) {} // -1 because for some reason tracing results in this upperbound
+      max_seq_len_(max_seq_len > 0 ? max_seq_len - 1 : 127) {
+} // -1 because for some reason tracing results in this upperbound

 ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
     std::vector<uint64_t>& prompt_tokens,
@@ -33,33 +34,35 @@ ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
   if (!text_decoder_runner_->is_method_loaded()) {
     ET_CHECK_OK_OR_RETURN_ERROR(text_decoder_runner_->load());
   }
-
+
   // Check if we need to chunk the prompt tokens
   int32_t num_prompt_tokens = prompt_tokens.size();
-
+
   // If prompt tokens exceed max_seq_len_, we need to chunk them
   if (num_prompt_tokens > max_seq_len_) {
     uint64_t cur_token = 0;
     int num_tokens_to_process = 0;
-
+
     while (num_tokens_to_process < num_prompt_tokens) {
-      auto num_tokens_to_prefill_with =
-          std::min<int>(num_prompt_tokens - num_tokens_to_process, max_seq_len_);
-
-      std::vector<uint64_t> prompt_tokens_to_process(num_tokens_to_prefill_with);
+      auto num_tokens_to_prefill_with = std::min<int>(
+          num_prompt_tokens - num_tokens_to_process, max_seq_len_);
+
+      std::vector<uint64_t> prompt_tokens_to_process(
+          num_tokens_to_prefill_with);
       std::copy(
           prompt_tokens.begin() + num_tokens_to_process,
-          prompt_tokens.begin() + num_tokens_to_process + num_tokens_to_prefill_with,
+          prompt_tokens.begin() + num_tokens_to_process +
+              num_tokens_to_prefill_with,
          prompt_tokens_to_process.begin());
-
+
       // Process this chunk
       auto chunk_result = prefillChunk(prompt_tokens_to_process, start_pos);
       ET_CHECK_OK_OR_RETURN_ERROR(chunk_result.error());
       cur_token = chunk_result.get();
-
+
       num_tokens_to_process += num_tokens_to_prefill_with;
     }
-
+
     return cur_token;
   } else {
     // If prompt tokens don't exceed max_seq_len_, process them directly
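
The heart of the change is the new chunking loop above: a prompt longer than max_seq_len_ is consumed in consecutive slices of at most max_seq_len_ tokens, each slice is handed to prefillChunk at the current KV-cache position, and the result of the last chunk is returned. The sketch below is a self-contained illustration of that arithmetic only; fake_prefill_chunk is a hypothetical stand-in for TextPrefiller::prefillChunk (assumed here to advance start_pos), and no ExecuTorch types or APIs are used.

```cpp
// Minimal, self-contained sketch of the chunked-prefill loop added in this
// commit. fake_prefill_chunk is a hypothetical stand-in for
// TextPrefiller::prefillChunk: it returns a dummy "next token" and advances
// start_pos by the chunk size (an assumption; the real update happens inside
// ExecuTorch).
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

static uint64_t fake_prefill_chunk(
    const std::vector<uint64_t>& chunk,
    int64_t& start_pos) {
  start_pos += static_cast<int64_t>(chunk.size());
  return chunk.back() + 1; // pretend the model predicted "last token + 1"
}

int main() {
  const int32_t max_seq_len = 8; // plays the role of max_seq_len_
  std::vector<uint64_t> prompt_tokens(20);
  for (size_t i = 0; i < prompt_tokens.size(); ++i) {
    prompt_tokens[i] = static_cast<uint64_t>(i);
  }

  int64_t start_pos = 0;
  uint64_t cur_token = 0;
  const int32_t num_prompt_tokens = static_cast<int32_t>(prompt_tokens.size());
  int32_t num_tokens_to_process = 0;

  // Same shape as the loop in TextPrefiller::prefill: take at most
  // max_seq_len tokens per iteration until the whole prompt is consumed.
  while (num_tokens_to_process < num_prompt_tokens) {
    const int32_t num_tokens_to_prefill_with = std::min<int32_t>(
        num_prompt_tokens - num_tokens_to_process, max_seq_len);
    std::vector<uint64_t> chunk(
        prompt_tokens.begin() + num_tokens_to_process,
        prompt_tokens.begin() + num_tokens_to_process +
            num_tokens_to_prefill_with);
    cur_token = fake_prefill_chunk(chunk, start_pos);
    num_tokens_to_process += num_tokens_to_prefill_with;
  }

  // With 20 tokens and max_seq_len = 8 this consumes chunks of 8, 8 and 4.
  std::cout << "consumed " << num_tokens_to_process << " tokens, final pos "
            << start_pos << ", last predicted token " << cur_token << "\n";
  return 0;
}
```

Note that in the real runner the guard `num_prompt_tokens > max_seq_len_` keeps the single-chunk fast path unchanged; only oversized prompts take the copy-into-prompt_tokens_to_process path.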

extension/llm/runner/text_prefiller.h

Lines changed: 2 additions & 1 deletion
@@ -39,7 +39,8 @@ class ET_EXPERIMENTAL TextPrefiller {
   /**
    * Helper method to prefill a chunk of tokens.
    * @param prompt_tokens The chunk of text prompt tokens to process.
-   * @param start_pos The starting position in KV cache of the input in the LLM Module.
+   * @param start_pos The starting position in KV cache of the input in the LLM
+   * Module.
    * @return The next token of the LLM Module after prefilling this chunk.
    */
   ::executorch::runtime::Result<uint64_t> prefillChunk(
