From 949b0ad2647c29669c0ff82265631f71f393571c Mon Sep 17 00:00:00 2001
From: Jack <32371937+jackzhxng@users.noreply.github.com>
Date: Mon, 2 Jun 2025 13:39:25 -0700
Subject: [PATCH] Reland #66 and #67 (#74)

Summary:
Reland https://github.com/pytorch-labs/tokenizers/issues/66 and https://github.com/pytorch-labs/tokenizers/issues/67 with unbypassable arc lint fixes

Reviewed By: kirklandsign

Differential Revision: D74693197

Pulled By: jackzhxng
---
 .../pytorch/tokenizers/bpe_tokenizer_base.h | 21 +++++++++++++++++++
 src/hf_tokenizer.cpp                        | 17 +++++++++++++--
 src/tiktoken.cpp                            | 18 +---------------
 targets.bzl                                 |  3 +++
 4 files changed, 40 insertions(+), 19 deletions(-)

diff --git a/include/pytorch/tokenizers/bpe_tokenizer_base.h b/include/pytorch/tokenizers/bpe_tokenizer_base.h
index 16c0456..97542bf 100644
--- a/include/pytorch/tokenizers/bpe_tokenizer_base.h
+++ b/include/pytorch/tokenizers/bpe_tokenizer_base.h
@@ -25,6 +25,8 @@
 #include <pytorch/tokenizers/string_integer_map.h>
 #include <pytorch/tokenizers/tokenizer.h>
 
+#include "re2/re2.h"
+
 namespace tokenizers {
 namespace detail {
 
@@ -104,6 +106,25 @@ static Result<TokenMap> buildTokenMap(
   return buildTokenMap(std::move(pairs));
 }
 
+inline Result<std::unique_ptr<IRegex>> build_special_token_regex(
+    const TokenMap& special_token_map) {
+  std::string special_pattern;
+  const std::size_t count = special_token_map.size();
+
+  for (std::size_t i = 0; i < count; ++i) {
+    const auto& [token, _] = special_token_map.getElement(i);
+    if (!special_pattern.empty()) {
+      special_pattern += "|";
+    }
+    special_pattern += re2::RE2::QuoteMeta(std::string(token));
+  }
+
+  if (special_pattern.empty()) {
+    return static_cast<std::unique_ptr<IRegex>>(nullptr);
+  }
+  return create_regex(special_pattern);
+}
+
 class BPETokenizerBase : public Tokenizer {
  public:
   Result<std::vector<uint64_t>>
diff --git a/src/hf_tokenizer.cpp b/src/hf_tokenizer.cpp
index 44b68a5..fa62264 100644
--- a/src/hf_tokenizer.cpp
+++ b/src/hf_tokenizer.cpp
@@ -69,6 +69,12 @@ Error HFTokenizer::load(const std::string& path) {
         special_tokens,
         [](const auto& it) -> std::string { return it.at("content"); },
         [](const auto& it) -> std::uint64_t { return it.at("id"); }));
+
+    // Create special token regex to help later with encoding.
+    special_token_regex_ =
+        TK_UNWRAP(detail::build_special_token_regex(special_token_map));
+
+    // Store for future use.
     special_token_map_.emplace(std::move(special_token_map));
   } catch (const json::out_of_range& e) {
     fprintf(stderr, "Could not parse special tokens: %s\n", e.what());
@@ -142,8 +148,15 @@ Error HFTokenizer::load(const std::string& path) {
 
   // Pull out the token strings
   try {
-    const std::string bos_token = parsed_config_json.at("bos_token");
-    const std::string eos_token = parsed_config_json.at("eos_token");
+    const std::string bos_token = parsed_config_json.contains("bos_token") &&
+            !parsed_config_json["bos_token"].is_null()
+        ? parsed_config_json["bos_token"].get<std::string>()
+        : "";
+
+    const std::string eos_token = parsed_config_json.contains("eos_token") &&
+            !parsed_config_json["eos_token"].is_null()
+        ? parsed_config_json["eos_token"].get<std::string>()
+        : "";
     const auto bos_res = special_token_map_->tryGetInteger(bos_token);
     const auto eos_res = special_token_map_->tryGetInteger(eos_token);
     if (!bos_res) {
diff --git a/src/tiktoken.cpp b/src/tiktoken.cpp
index 57b3cbc..c112221 100644
--- a/src/tiktoken.cpp
+++ b/src/tiktoken.cpp
@@ -32,7 +32,6 @@
 #include <cinttypes>
 #include <fstream>
 #include <limits>
-#include "re2/re2.h"
 
 namespace tokenizers {
 
@@ -47,21 +46,6 @@ static Result<std::unique_ptr<IRegex>> _create_regex(
   return create_regex(pattern);
 }
 
-static Result<std::unique_ptr<IRegex>> _build_special_token_regex(
-    const std::vector<std::pair<std::string, std::uint64_t>>& special_encoder) {
-  std::string special_pattern;
-  for (const auto& ele : special_encoder) {
-    if (!special_pattern.empty()) {
-      special_pattern += "|";
-    }
-    special_pattern += re2::RE2::QuoteMeta(ele.first);
-  }
-  if (special_pattern.empty()) {
-    return static_cast<std::unique_ptr<IRegex>>(nullptr);
-  }
-  return _create_regex(special_pattern);
-}
-
 static Result<std::pair<std::string, std::uint64_t>> _parse(
     const std::string& line) {
   // Tiktoken format
@@ -153,7 +137,7 @@ Error Tiktoken::load(const std::string& path) {
   _regex = TK_UNWRAP(_create_regex(_pattern));
 
   special_token_regex_ =
-      TK_UNWRAP(_build_special_token_regex(special_token_map));
+      TK_UNWRAP(detail::build_special_token_regex(TokenMap(special_token_map)));
 
   // initialize vocab_size, bos_tok, eos_tok
   vocab_size_ = token_map_->size() + special_token_map_->size();
diff --git a/targets.bzl b/targets.bzl
index 5271317..2a2bf5c 100644
--- a/targets.bzl
+++ b/targets.bzl
@@ -77,6 +77,9 @@ def define_common_targets():
         exported_deps = [
             ":headers",
         ],
+        exported_external_deps = [
+            "re2",
+        ],
        visibility = [
            "//pytorch/tokenizers/...",
        ],
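
Note on the shared helper: detail::build_special_token_regex escapes each special token with re2::RE2::QuoteMeta and joins the escaped tokens with "|", so the resulting pattern matches any special token literally. The standalone sketch below reproduces that behavior using re2 directly; the build_special_pattern function and the example token strings are illustrative assumptions, not part of this patch or of the tokenizers API.

    #include <iostream>
    #include <string>
    #include <vector>

    #include "re2/re2.h"

    // Join literally-escaped special tokens into one alternation pattern,
    // mirroring what detail::build_special_token_regex does in this patch.
    std::string build_special_pattern(const std::vector<std::string>& tokens) {
      std::string pattern;
      for (const auto& token : tokens) {
        if (!pattern.empty()) {
          pattern += "|";
        }
        // QuoteMeta escapes regex metacharacters such as '|', '<' and '.'.
        pattern += re2::RE2::QuoteMeta(token);
      }
      return pattern;
    }

    int main() {
      const std::vector<std::string> specials = {
          "<|begin_of_text|>", "<|end_of_text|>"};

      // Wrap the alternation in a capture group so the matched token can be
      // read back through PartialMatch.
      re2::RE2 special_regex("(" + build_special_pattern(specials) + ")");

      std::string match;
      // PartialMatch finds the first occurrence of any special token.
      if (re2::RE2::PartialMatch(
              "<|begin_of_text|>hello world<|end_of_text|>",
              special_regex,
              &match)) {
        std::cout << "matched special token: " << match << "\n";
      }
      return 0;
    }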