@@ -24,7 +24,8 @@ TextPrefiller::TextPrefiller(
24
24
: text_decoder_runner_(text_decoder_runner),
25
25
use_kv_cache_ (use_kv_cache),
26
26
enable_parallel_prefill_(enable_parallel_prefill),
27
- max_seq_len_(max_seq_len > 0 ? max_seq_len - 1 : 127 ) {} // -1 because for some reason tracing results in this upperbound
27
+ max_seq_len_(max_seq_len > 0 ? max_seq_len - 1 : 127 ) {
28
+ } // -1 because, for some reason, tracing results in this upper bound
28
29
29
30
::executorch::runtime::Result<uint64_t > TextPrefiller::prefill (
30
31
std::vector<uint64_t >& prompt_tokens,
@@ -33,33 +34,35 @@ ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
33
34
if (!text_decoder_runner_->is_method_loaded ()) {
34
35
ET_CHECK_OK_OR_RETURN_ERROR (text_decoder_runner_->load ());
35
36
}
36
-
37
+
37
38
// Check if we need to chunk the prompt tokens
38
39
int32_t num_prompt_tokens = prompt_tokens.size ();
39
-
40
+
40
41
// If prompt tokens exceed max_seq_len_, we need to chunk them
41
42
if (num_prompt_tokens > max_seq_len_) {
42
43
uint64_t cur_token = 0 ;
43
44
int num_tokens_to_process = 0 ;
44
-
45
+
45
46
while (num_tokens_to_process < num_prompt_tokens) {
46
- auto num_tokens_to_prefill_with =
47
- std::min<int >(num_prompt_tokens - num_tokens_to_process, max_seq_len_);
48
-
49
- std::vector<uint64_t > prompt_tokens_to_process (num_tokens_to_prefill_with);
47
+ auto num_tokens_to_prefill_with = std::min<int >(
48
+ num_prompt_tokens - num_tokens_to_process, max_seq_len_);
49
+
50
+ std::vector<uint64_t > prompt_tokens_to_process (
51
+ num_tokens_to_prefill_with);
50
52
std::copy (
51
53
prompt_tokens.begin () + num_tokens_to_process,
52
- prompt_tokens.begin () + num_tokens_to_process + num_tokens_to_prefill_with,
54
+ prompt_tokens.begin () + num_tokens_to_process +
55
+ num_tokens_to_prefill_with,
53
56
prompt_tokens_to_process.begin ());
54
-
57
+
55
58
// Process this chunk
56
59
auto chunk_result = prefillChunk (prompt_tokens_to_process, start_pos);
57
60
ET_CHECK_OK_OR_RETURN_ERROR (chunk_result.error ());
58
61
cur_token = chunk_result.get ();
59
-
62
+
60
63
num_tokens_to_process += num_tokens_to_prefill_with;
61
64
}
62
-
65
+
63
66
return cur_token;
64
67
} else {
65
68
// If prompt tokens don't exceed max_seq_len_, process them directly
0 commit comments