@@ -11,6 +11,7 @@
 #include <executorch/examples/models/llama/runner/runner.h>
 
+#include <algorithm>
 #include <ctime>
 
 #include <executorch/extension/llm/runner/util.h>
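(The new <algorithm> include covers the std::min and std::copy calls introduced in the prefill loop further down.)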
@@ -221,11 +222,11 @@ Error Runner::generate(
 
   ET_CHECK_MSG(num_prompt_tokens >= 1, "Expected at least 1 prompt token");
   ET_CHECK_MSG(
-      num_prompt_tokens < metadata_.at(kMaxSeqLen),
+      num_prompt_tokens < metadata_.at(kMaxContextLen),
       "num_prompt_tokens %d >= max_seq_len_ %" PRId64
       ", Max seq length exceeded - please increase max seq len value in .../llama2/model.py",
       num_prompt_tokens,
-      metadata_.at(kMaxSeqLen));
+      metadata_.at(kMaxContextLen));
   ET_CHECK_MSG(
       num_prompt_tokens < seq_len,
       "num_prompt_tokens %d >= seq_len %d, Sequence length exceeded - please increase the seq_len value passed to generate()",
@@ -241,11 +242,26 @@ Error Runner::generate(
     wrapped_callback(prompt);
   }
   int64_t pos = 0;
-  auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos);
+  uint64_t cur_token;
+  int max_seq_len = metadata_.at(kMaxSeqLen) -
+      1; // -1 because for some reason tracing results in this upperbound
+  int num_tokens_to_process = 0;
+  while (num_tokens_to_process < num_prompt_tokens) {
+    auto num_tokens_to_prefill_with =
+        std::min(num_prompt_tokens - num_tokens_to_process, max_seq_len);
+    std::vector<uint64_t> prompt_tokens_to_process(num_tokens_to_prefill_with);
+    std::copy(
+        prompt_tokens.begin() + num_tokens_to_process,
+        prompt_tokens.begin() + num_tokens_to_process + num_tokens_to_prefill_with,
+        prompt_tokens_to_process.begin());
+    auto prefill_res =
+        text_prefiller_->prefill(prompt_tokens_to_process, pos);
+    ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
+    cur_token = prefill_res.get();
+    num_tokens_to_process += num_tokens_to_prefill_with;
+  }
   stats_.first_token_ms = llm::time_in_ms();
   stats_.prompt_eval_end_ms = llm::time_in_ms();
-  ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
-  uint64_t cur_token = prefill_res.get();
 
   // print the first token from prefill. No prev_token so use cur_token for it.
   wrapped_callback(
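For context, the replaced single prefill call becomes a loop that walks the prompt in windows of at most max_seq_len tokens, prefilling each window back to back. Below is a minimal, self-contained sketch of that chunking arithmetic; fake_prefill is a hypothetical stand-in for text_prefiller_->prefill and, like the real call appears to, advances pos by the number of tokens consumed:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for TextPrefiller::prefill(): consumes one chunk
// starting at `pos`, advances `pos` past it, and returns a "next token"
// (here just a dummy value instead of a sampled model output).
uint64_t fake_prefill(const std::vector<uint64_t>& chunk, int64_t& pos) {
  pos += static_cast<int64_t>(chunk.size());
  return chunk.back();
}

int main() {
  std::vector<uint64_t> prompt_tokens(10); // pretend 10-token prompt
  for (size_t i = 0; i < prompt_tokens.size(); ++i) {
    prompt_tokens[i] = i;
  }
  int num_prompt_tokens = static_cast<int>(prompt_tokens.size());
  int max_seq_len = 4 - 1; // per-pass cap minus one, mirroring the patch
  int64_t pos = 0;
  uint64_t cur_token = 0;

  int num_tokens_to_process = 0;
  while (num_tokens_to_process < num_prompt_tokens) {
    int num_tokens_to_prefill_with =
        std::min(num_prompt_tokens - num_tokens_to_process, max_seq_len);
    // Copy the next window out of the full prompt, as the patch does.
    std::vector<uint64_t> chunk(
        prompt_tokens.begin() + num_tokens_to_process,
        prompt_tokens.begin() + num_tokens_to_process +
            num_tokens_to_prefill_with);
    cur_token = fake_prefill(chunk, pos);
    num_tokens_to_process += num_tokens_to_prefill_with;
    std::printf(
        "prefilled %d/%d tokens, pos=%lld, cur_token=%llu\n",
        num_tokens_to_process,
        num_prompt_tokens,
        static_cast<long long>(pos),
        static_cast<unsigned long long>(cur_token));
  }
  return 0;
}

One consequence visible in the diff: stats_.first_token_ms is now recorded only after every chunk has been prefilled, so time-to-first-token for long prompts includes all prefill passes.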