
Commit da9f2af

[ExecuTorch][Llama] Change runner to enable chunked prefill
This diff adds code to chunk prompts longer than max_seq_len, enabling prefill of larger contexts.

Differential Revision: [D71833061](https://our.internmc.facebook.com/intern/diff/D71833061/)
ghstack-source-id: 275212294
Pull Request resolved: #9785
1 parent: 2aa7748

File tree

1 file changed: +21 -5 lines

examples/models/llama/runner/runner.cpp

Lines changed: 21 additions & 5 deletions
@@ -11,6 +11,7 @@
 
 #include <executorch/examples/models/llama/runner/runner.h>
 
+#include <algorithm>
 #include <ctime>
 
 #include <executorch/extension/llm/runner/util.h>
@@ -221,11 +222,11 @@ Error Runner::generate(
 
   ET_CHECK_MSG(num_prompt_tokens >= 1, "Expected at least 1 prompt token");
   ET_CHECK_MSG(
-      num_prompt_tokens < metadata_.at(kMaxSeqLen),
+      num_prompt_tokens < metadata_.at(kMaxContextLen),
       "num_prompt_tokens %d >= max_seq_len_ %" PRId64
       ", Max seq length exceeded - please increase max seq len value in .../llama2/model.py",
       num_prompt_tokens,
-      metadata_.at(kMaxSeqLen));
+      metadata_.at(kMaxContextLen));
   ET_CHECK_MSG(
       num_prompt_tokens < seq_len,
       "num_prompt_tokens %d >= seq_len %d, Sequence length exceeded - please increase the seq_len value passed to generate()",
@@ -241,11 +242,26 @@ Error Runner::generate(
     wrapped_callback(prompt);
   }
   int64_t pos = 0;
-  auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos);
+  uint64_t cur_token;
+  int max_seq_len = metadata_.at(kMaxSeqLen) -
+      1; // -1 because for some reason tracing results in this upperbound
+  int num_tokens_to_process = 0;
+  while (num_tokens_to_process < num_prompt_tokens) {
+    auto num_tokens_to_prefill_with =
+        std::min(num_prompt_tokens - num_tokens_to_process, max_seq_len);
+    std::vector<uint64_t> prompt_tokens_to_process(num_tokens_to_prefill_with);
+    std::copy(
+        prompt_tokens.begin() + num_tokens_to_process,
+        prompt_tokens.begin() + num_tokens_to_process + num_tokens_to_prefill_with,
+        prompt_tokens_to_process.begin());
+    auto prefill_res =
+        text_prefiller_->prefill(prompt_tokens_to_process, pos);
+    ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
+    cur_token = prefill_res.get();
+    num_tokens_to_process += num_tokens_to_prefill_with;
+  }
   stats_.first_token_ms = llm::time_in_ms();
   stats_.prompt_eval_end_ms = llm::time_in_ms();
-  ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
-  uint64_t cur_token = prefill_res.get();
 
   // print the first token from prefill. No prev_token so use cur_token for it.
   wrapped_callback(
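
As a standalone illustration of the chunking logic introduced above, here is a minimal C++ sketch that walks a prompt in chunks of at most max_seq_len tokens. It is not part of the commit: chunked_prefill, prefill_chunk, and fake_prefill are hypothetical names standing in for the runner's text_prefiller_->prefill() call, and the sketch assumes the prefill callback advances the cache position by the number of tokens it consumed, which is what keeps the loop in the diff consistent across chunks.

// Minimal sketch of a chunked-prefill loop, assuming a prefill callback that
// consumes one chunk of tokens and returns the next predicted token.
// chunked_prefill and prefill_chunk are illustrative names, not ExecuTorch APIs.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <functional>
#include <vector>

uint64_t chunked_prefill(
    const std::vector<uint64_t>& prompt_tokens,
    int max_seq_len, // largest chunk the model can prefill at once
    const std::function<uint64_t(const std::vector<uint64_t>&, int64_t&)>&
        prefill_chunk) {
  uint64_t cur_token = 0;
  int64_t pos = 0; // running position in the KV cache
  int num_tokens_to_process = 0;
  const int num_prompt_tokens = static_cast<int>(prompt_tokens.size());
  while (num_tokens_to_process < num_prompt_tokens) {
    // Take at most max_seq_len tokens per prefill call.
    int chunk_len =
        std::min(num_prompt_tokens - num_tokens_to_process, max_seq_len);
    std::vector<uint64_t> chunk(
        prompt_tokens.begin() + num_tokens_to_process,
        prompt_tokens.begin() + num_tokens_to_process + chunk_len);
    cur_token = prefill_chunk(chunk, pos); // callback advances pos by chunk_len
    num_tokens_to_process += chunk_len;
  }
  return cur_token; // first generated token after the full prompt
}

int main() {
  // 300 prompt tokens with max_seq_len 127 prefill in chunks of 127, 127, 46.
  std::vector<uint64_t> prompt(300);
  for (size_t i = 0; i < prompt.size(); ++i) prompt[i] = i;
  auto fake_prefill = [](const std::vector<uint64_t>& chunk, int64_t& pos) {
    printf("prefilled %zu tokens starting at pos %lld\n",
           chunk.size(), static_cast<long long>(pos));
    pos += static_cast<int64_t>(chunk.size());
    return chunk.back() + 1; // placeholder for the model's next token
  };
  chunked_prefill(prompt, 128 - 1, fake_prefill);
  return 0;
}

Only the last chunk's return value matters: it becomes cur_token, the first token echoed after prefill, while the earlier chunks only populate the KV cache. With 300 prompt tokens and kMaxSeqLen of 128, the diff's loop (which subtracts 1 from kMaxSeqLen) prefills 127, 127, and then 46 tokens.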
