From 606b1b214f022a76e048f1ca67d0f6a1f0ecaf8c Mon Sep 17 00:00:00 2001
From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com>
Date: Wed, 30 Apr 2025 23:12:50 -0700
Subject: [PATCH 1/2] Handle null bos and eos token

---
 src/hf_tokenizer.cpp | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/src/hf_tokenizer.cpp b/src/hf_tokenizer.cpp
index 44b68a5..8f06889 100644
--- a/src/hf_tokenizer.cpp
+++ b/src/hf_tokenizer.cpp
@@ -142,8 +142,15 @@ Error HFTokenizer::load(const std::string& path) {
 
   // Pull out the token strings
   try {
-    const std::string bos_token = parsed_config_json.at("bos_token");
-    const std::string eos_token = parsed_config_json.at("eos_token");
+    const std::string bos_token =
+        parsed_config_json.contains("bos_token") && !parsed_config_json["bos_token"].is_null()
+        ? parsed_config_json["bos_token"].get<std::string>()
+        : "";
+
+    const std::string eos_token =
+        parsed_config_json.contains("eos_token") && !parsed_config_json["eos_token"].is_null()
+        ? parsed_config_json["eos_token"].get<std::string>()
+        : "";
     const auto bos_res = special_token_map_->tryGetInteger(bos_token);
     const auto eos_res = special_token_map_->tryGetInteger(eos_token);
     if (!bos_res) {

From b93cd82efc5f7331647ddf594194ebdbd07a1a65 Mon Sep 17 00:00:00 2001
From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com>
Date: Thu, 1 May 2025 08:37:42 -0700
Subject: [PATCH 2/2] Fix hf tokenizer handling of special tokens

---
 .../pytorch/tokenizers/bpe_tokenizer_base.h   | 25 +++++++++++++++++++
 src/hf_tokenizer.cpp                          |  5 ++++
 src/tiktoken.cpp                              | 17 +------------
 3 files changed, 31 insertions(+), 16 deletions(-)

diff --git a/include/pytorch/tokenizers/bpe_tokenizer_base.h b/include/pytorch/tokenizers/bpe_tokenizer_base.h
index 16c0456..06bcc21 100644
--- a/include/pytorch/tokenizers/bpe_tokenizer_base.h
+++ b/include/pytorch/tokenizers/bpe_tokenizer_base.h
@@ -25,6 +25,8 @@
 #include <pytorch/tokenizers/result.h>
 #include <pytorch/tokenizers/tokenizer.h>
 
+#include "re2/re2.h"
+
 namespace tokenizers {
 namespace detail {
 
@@ -104,6 +106,29 @@ static Result<TokenMap> buildTokenMap(
   return buildTokenMap(std::move(pairs));
 }
 
+static Result<std::unique_ptr<IRegex>> build_special_token_regex(
+    const TokenMap& special_token_map) {
+  std::string special_pattern;
+  const std::size_t count = special_token_map.size();
+
+  std::cout << "iterating" << std::endl;
+  for (std::size_t i = 0; i < count; ++i) {
+    std::cout << "i: " << i << "/" << count << std::endl;
+    const auto& [token, _] = special_token_map.getElement(i);
+    std::cout << "token: " << token << std::endl;
+    if (!special_pattern.empty()) {
+      special_pattern += "|";
+    }
+    special_pattern += re2::RE2::QuoteMeta(std::string(token));
+  }
+  std::cout << "special pattern: " << special_pattern << std::endl;
+
+  if (special_pattern.empty()) {
+    return static_cast<std::unique_ptr<IRegex>>(nullptr);
+  }
+  return create_regex(special_pattern);
+}
+
 class BPETokenizerBase : public Tokenizer {
  public:
   Result<std::vector<uint64_t>>
diff --git a/src/hf_tokenizer.cpp b/src/hf_tokenizer.cpp
index 8f06889..a04005b 100644
--- a/src/hf_tokenizer.cpp
+++ b/src/hf_tokenizer.cpp
@@ -69,6 +69,11 @@ Error HFTokenizer::load(const std::string& path) {
         special_tokens,
         [](const auto& it) -> std::string { return it.at("content"); },
         [](const auto& it) -> std::uint64_t { return it.at("id"); }));
+
+    // Create special token regex to help later with encoding.
+    special_token_regex_ = TK_UNWRAP(detail::build_special_token_regex(special_token_map));
+
+    // Store for future use.
     special_token_map_.emplace(std::move(special_token_map));
   } catch (const json::out_of_range& e) {
     fprintf(stderr, "Could not parse special tokens: %s\n", e.what());
diff --git a/src/tiktoken.cpp b/src/tiktoken.cpp
index 57b3cbc..9a3565d 100644
--- a/src/tiktoken.cpp
+++ b/src/tiktoken.cpp
@@ -32,7 +32,6 @@
 #include <cinttypes>
 #include <fstream>
 #include <limits>
-#include "re2/re2.h"
 
 namespace tokenizers {
 
@@ -47,20 +46,6 @@ static Result<std::unique_ptr<IRegex>> _create_regex(
   return create_regex(pattern);
 }
 
-static Result<std::unique_ptr<IRegex>> _build_special_token_regex(
-    const std::vector<std::pair<std::string, std::uint64_t>>& special_encoder) {
-  std::string special_pattern;
-  for (const auto& ele : special_encoder) {
-    if (!special_pattern.empty()) {
-      special_pattern += "|";
-    }
-    special_pattern += re2::RE2::QuoteMeta(ele.first);
-  }
-  if (special_pattern.empty()) {
-    return static_cast<std::unique_ptr<IRegex>>(nullptr);
-  }
-  return _create_regex(special_pattern);
-}
 
 static Result<std::pair<std::string, uint64_t>> _parse(
     const std::string& line) {
@@ -153,7 +138,7 @@ Error Tiktoken::load(const std::string& path) {
 
   _regex = TK_UNWRAP(_create_regex(_pattern));
   special_token_regex_ =
-      TK_UNWRAP(_build_special_token_regex(special_token_map));
+      TK_UNWRAP(detail::build_special_token_regex(TokenMap(special_token_map)));
 
   // initialize vocab_size, bos_tok, eos_tok
   vocab_size_ = token_map_->size() + special_token_map_->size();