diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp index 1831d46f39..970573713b 100644 --- a/torchtext/csrc/vocab.cpp +++ b/torchtext/csrc/vocab.cpp @@ -1,4 +1,5 @@ #include // @manual +#include #include #include #include @@ -8,7 +9,8 @@ namespace torchtext { Vocab::Vocab(const StringList &tokens, const std::string &unk_token) - : stoi_(MAX_VOCAB_SIZE, -1), unk_token_(std::move(unk_token)) { + : unk_token_(std::move(unk_token)) { + std::fill(stoi_.begin(), stoi_.end(), -1); for (std::size_t i = 0; i < tokens.size(); i++) { // tokens should not have any duplicates auto token_position = diff --git a/torchtext/csrc/vocab.h b/torchtext/csrc/vocab.h index d915c7de27..b9b17a3263 100644 --- a/torchtext/csrc/vocab.h +++ b/torchtext/csrc/vocab.h @@ -12,7 +12,7 @@ typedef std::tuple, std::vector, struct Vocab : torch::CustomClassHolder { static const int32_t MAX_VOCAB_SIZE = 30000000; int64_t unk_index_; - std::vector stoi_; + std::array stoi_; const std::string version_str_ = "0.0.1"; StringList itos_; std::string unk_token_; @@ -44,7 +44,7 @@ struct Vocab : torch::CustomClassHolder { uint32_t _find(const c10::string_view &w) const { uint32_t stoi_size = stoi_.size(); uint32_t id = _hash(w) % stoi_size; - while (stoi_[id] != -1 && itos_[stoi_[id]]!= w) { + while (stoi_[id] != -1 && itos_[stoi_[id]] != w) { id = (id + 1) % stoi_size; } return id;