diff --git a/include/pytorch/tokenizers/bpe_tokenizer_base.h b/include/pytorch/tokenizers/bpe_tokenizer_base.h
index 16c0456..1753d86 100644
--- a/include/pytorch/tokenizers/bpe_tokenizer_base.h
+++ b/include/pytorch/tokenizers/bpe_tokenizer_base.h
@@ -25,6 +25,8 @@
 #include <pytorch/tokenizers/result.h>
 #include <pytorch/tokenizers/tokenizer.h>
 
+#include "re2/re2.h"
+
 namespace tokenizers {
 namespace detail {
 
@@ -104,6 +106,25 @@ static Result<TokenMap> buildTokenMap(
   return buildTokenMap(std::move(pairs));
 }
 
+static Result<std::unique_ptr<IRegex>> build_special_token_regex(
+    const TokenMap& special_token_map) {
+  std::string special_pattern;
+  const std::size_t count = special_token_map.size();
+
+  for (std::size_t i = 0; i < count; ++i) {
+    const auto& [token, _] = special_token_map.getElement(i);
+    if (!special_pattern.empty()) {
+      special_pattern += "|";
+    }
+    special_pattern += re2::RE2::QuoteMeta(std::string(token));
+  }
+
+  if (special_pattern.empty()) {
+    return static_cast<std::unique_ptr<IRegex>>(nullptr);
+  }
+  return create_regex(special_pattern);
+}
+
 class BPETokenizerBase : public Tokenizer {
  public:
   Result<std::vector<uint64_t>>
diff --git a/src/hf_tokenizer.cpp b/src/hf_tokenizer.cpp
index 8f06889..a04005b 100644
--- a/src/hf_tokenizer.cpp
+++ b/src/hf_tokenizer.cpp
@@ -69,6 +69,11 @@ Error HFTokenizer::load(const std::string& path) {
         special_tokens,
         [](const auto& it) -> std::string { return it.at("content"); },
         [](const auto& it) -> std::uint64_t { return it.at("id"); }));
+
+    // Create special token regex to help later with encoding.
+    special_token_regex_ = TK_UNWRAP(detail::build_special_token_regex(special_token_map));
+
+    // Store for future use.
     special_token_map_.emplace(std::move(special_token_map));
   } catch (const json::out_of_range& e) {
     fprintf(stderr, "Could not parse special tokens: %s\n", e.what());
diff --git a/src/tiktoken.cpp b/src/tiktoken.cpp
index 57b3cbc..9a3565d 100644
--- a/src/tiktoken.cpp
+++ b/src/tiktoken.cpp
@@ -32,7 +32,6 @@
 #include <cinttypes>
 #include <fstream>
 #include <limits>
-#include "re2/re2.h"
 
 namespace tokenizers {
 
@@ -47,20 +46,6 @@ static Result<std::unique_ptr<IRegex>> _create_regex(
     const std::string& pattern) {
   return create_regex(pattern);
 }
 
-static Result<std::unique_ptr<IRegex>> _build_special_token_regex(
-    const std::vector<std::pair<std::string, std::uint64_t>>& special_encoder) {
-  std::string special_pattern;
-  for (const auto& ele : special_encoder) {
-    if (!special_pattern.empty()) {
-      special_pattern += "|";
-    }
-    special_pattern += re2::RE2::QuoteMeta(ele.first);
-  }
-  if (special_pattern.empty()) {
-    return static_cast<std::unique_ptr<IRegex>>(nullptr);
-  }
-  return _create_regex(special_pattern);
-}
 static Result<std::pair<std::string, std::uint64_t>> _parse(
     const std::string& line) {
@@ -153,7 +138,7 @@ Error Tiktoken::load(const std::string& path) {
   _regex = TK_UNWRAP(_create_regex(_pattern));
 
   special_token_regex_ =
-      TK_UNWRAP(_build_special_token_regex(special_token_map));
+      TK_UNWRAP(detail::build_special_token_regex(TokenMap(special_token_map)));
 
   // initialize vocab_size, bos_tok, eos_tok
   vocab_size_ = token_map_->size() + special_token_map_->size();
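
Note on the shared helper (not part of the patch): detail::build_special_token_regex centralizes the escaping/alternation logic that Tiktoken and HFTokenizer previously duplicated. The standalone sketch below illustrates what that loop produces, using only std:: and RE2; the token strings are hypothetical stand-ins for the TokenMap contents, and IRegex/create_regex from the library are deliberately not used here.

// Standalone sketch of the pattern construction performed by
// detail::build_special_token_regex. Assumes only <string>, <vector>,
// <iostream>, and RE2; the special tokens below are made-up examples.
#include <iostream>
#include <string>
#include <vector>

#include "re2/re2.h"

int main() {
  const std::vector<std::string> special = {"<|endoftext|>", "<|fim_prefix|>"};

  // OR-join the QuoteMeta-escaped tokens, mirroring the helper's loop.
  std::string pattern;
  for (const auto& tok : special) {
    if (!pattern.empty()) {
      pattern += "|";
    }
    pattern += re2::RE2::QuoteMeta(tok);  // escapes regex metacharacters: < | > etc.
  }

  std::cout << pattern << "\n";  // \<\|endoftext\|\>|\<\|fim_prefix\|\>

  // The compiled alternation matches any one special token literally.
  re2::RE2 re(pattern);
  std::cout << re2::RE2::PartialMatch("foo <|endoftext|>", re) << "\n";  // 1
  return 0;
}

Because every token is passed through RE2::QuoteMeta, metacharacter-heavy tokens such as <|endoftext|> match only literally, and an empty token map yields a null regex rather than an empty (match-everything) pattern, which is why the helper returns nullptr in that case.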