00001 #ifndef LM_VOCAB_H
00002 #define LM_VOCAB_H
00003
00004 #include "lm/enumerate_vocab.hh"
00005 #include "lm/lm_exception.hh"
00006 #include "lm/virtual_interface.hh"
00007 #include "util/file_stream.hh"
00008 #include "util/murmur_hash.hh"
00009 #include "util/pool.hh"
00010 #include "util/probing_hash_table.hh"
00011 #include "util/sorted_uniform.hh"
00012 #include "util/string_piece.hh"
00013
00014 #include <limits>
00015 #include <string>
00016 #include <vector>
00017
00018 namespace lm {
00019 struct ProbBackoff;
00020 class EnumerateVocab;
00021
00022 namespace ngram {
00023 struct Config;
00024
00025 namespace detail {
00026 uint64_t HashForVocab(const char *str, std::size_t len);
00027 inline uint64_t HashForVocab(const StringPiece &str) {
00028 return HashForVocab(str.data(), str.length());
00029 }
00030 struct ProbingVocabularyHeader;
00031 }
00032
00033
00034
00035 class ImmediateWriteWordsWrapper : public EnumerateVocab {
00036 public:
00037 ImmediateWriteWordsWrapper(EnumerateVocab *inner, int fd, uint64_t start);
00038
00039 void Add(WordIndex index, const StringPiece &str) {
00040 stream_ << str << '\0';
00041 if (inner_) inner_->Add(index, str);
00042 }
00043
00044 private:
00045 EnumerateVocab *inner_;
00046
00047 util::FileStream stream_;
00048 };
00049
00050
00051 class WriteWordsWrapper : public EnumerateVocab {
00052 public:
00053 WriteWordsWrapper(EnumerateVocab *inner);
00054
00055 void Add(WordIndex index, const StringPiece &str);
00056
00057 const std::string &Buffer() const { return buffer_; }
00058 void Write(int fd, uint64_t start);
00059
00060 private:
00061 EnumerateVocab *inner_;
00062
00063 std::string buffer_;
00064 };
00065
00066
00067 class SortedVocabulary : public base::Vocabulary {
00068 public:
00069 SortedVocabulary();
00070
00071 WordIndex Index(const StringPiece &str) const {
00072 const uint64_t *found;
00073 if (util::BoundedSortedUniformFind<const uint64_t*, util::IdentityAccessor<uint64_t>, util::Pivot64>(
00074 util::IdentityAccessor<uint64_t>(),
00075 begin_ - 1, 0,
00076 end_, std::numeric_limits<uint64_t>::max(),
00077 detail::HashForVocab(str), found)) {
00078 return found - begin_ + 1;
00079 } else {
00080 return 0;
00081 }
00082 }
00083
00084
00085 static uint64_t Size(uint64_t entries, const Config &config);
00086
00087
00088
00089
00090
00091 static void ComputeRenumbering(WordIndex types, int from_words, int to_words, std::vector<WordIndex> &mapping);
00092
00093
00094 WordIndex Bound() const { return bound_; }
00095
00096
00097 void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config);
00098
00099 void Relocate(void *new_start);
00100
00101 void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries);
00102
00103
00104 WordIndex Insert(const StringPiece &str);
00105
00106 void FinishedLoading(ProbBackoff *reorder_vocab);
00107
00108
00109 std::size_t UnkCountChangePadding() const { return SawUnk() ? 0 : sizeof(uint64_t); }
00110
00111 bool SawUnk() const { return saw_unk_; }
00112
00113 void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset);
00114
00115 uint64_t *&EndHack() { return end_; }
00116
00117 void Populated();
00118
00119 private:
00120 template <class T> void GenericFinished(T *reorder);
00121
00122 uint64_t *begin_, *end_;
00123
00124 WordIndex bound_;
00125
00126 bool saw_unk_;
00127
00128 EnumerateVocab *enumerate_;
00129
00130
00131 util::Pool string_backing_;
00132
00133 std::vector<StringPiece> strings_to_enumerate_;
00134 };
00135
00136 #pragma pack(push)
00137 #pragma pack(4)
00138 struct ProbingVocabularyEntry {
00139 uint64_t key;
00140 WordIndex value;
00141
00142 typedef uint64_t Key;
00143 uint64_t GetKey() const { return key; }
00144 void SetKey(uint64_t to) { key = to; }
00145
00146 static ProbingVocabularyEntry Make(uint64_t key, WordIndex value) {
00147 ProbingVocabularyEntry ret;
00148 ret.key = key;
00149 ret.value = value;
00150 return ret;
00151 }
00152 };
00153 #pragma pack(pop)
00154
00155
00156 class ProbingVocabulary : public base::Vocabulary {
00157 public:
00158 ProbingVocabulary();
00159
00160 WordIndex Index(const StringPiece &str) const {
00161 Lookup::ConstIterator i;
00162 return lookup_.Find(detail::HashForVocab(str), i) ? i->value : 0;
00163 }
00164
00165 static uint64_t Size(uint64_t entries, float probing_multiplier);
00166
00167 static uint64_t Size(uint64_t entries, const Config &config);
00168
00169
00170 WordIndex Bound() const { return bound_; }
00171
00172
00173 void SetupMemory(void *start, std::size_t allocated);
00174 void SetupMemory(void *start, std::size_t allocated, std::size_t , const Config &) {
00175 SetupMemory(start, allocated);
00176 }
00177
00178 void Relocate(void *new_start);
00179
00180 void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries);
00181
00182 WordIndex Insert(const StringPiece &str);
00183
00184 template <class Weights> void FinishedLoading(Weights * ) {
00185 InternalFinishedLoading();
00186 }
00187
00188 std::size_t UnkCountChangePadding() const { return 0; }
00189
00190 bool SawUnk() const { return saw_unk_; }
00191
00192 void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset);
00193
00194 private:
00195 void InternalFinishedLoading();
00196
00197 typedef util::ProbingHashTable<ProbingVocabularyEntry, util::IdentityHash> Lookup;
00198
00199 Lookup lookup_;
00200
00201 WordIndex bound_;
00202
00203 bool saw_unk_;
00204
00205 EnumerateVocab *enumerate_;
00206
00207 detail::ProbingVocabularyHeader *header_;
00208 };
00209
00210 void MissingUnknown(const Config &config) throw(SpecialWordMissingException);
00211 void MissingSentenceMarker(const Config &config, const char *str) throw(SpecialWordMissingException);
00212
00213 template <class Vocab> void CheckSpecials(const Config &config, const Vocab &vocab) throw(SpecialWordMissingException) {
00214 if (!vocab.SawUnk()) MissingUnknown(config);
00215 if (vocab.BeginSentence() == vocab.NotFound()) MissingSentenceMarker(config, "<s>");
00216 if (vocab.EndSentence() == vocab.NotFound()) MissingSentenceMarker(config, "</s>");
00217 }
00218
00219 class WriteUniqueWords {
00220 public:
00221 explicit WriteUniqueWords(int fd) : word_list_(fd) {}
00222
00223 void operator()(const StringPiece &word) {
00224 word_list_ << word << '\0';
00225 }
00226
00227 private:
00228 util::FileStream word_list_;
00229 };
00230
00231 class NoOpUniqueWords {
00232 public:
00233 NoOpUniqueWords() {}
00234 void operator()(const StringPiece &word) {}
00235 };
00236
00237 template <class NewWordAction = NoOpUniqueWords> class GrowableVocab {
00238 public:
00239 static std::size_t MemUsage(WordIndex content) {
00240 return Lookup::MemUsage(content > 2 ? content : 2);
00241 }
00242
00243
00244 template <class NewWordConstruct> GrowableVocab(WordIndex initial_size, const NewWordConstruct &new_word_construct = NewWordAction())
00245 : lookup_(initial_size), new_word_(new_word_construct) {
00246 FindOrInsert("<unk>");
00247 FindOrInsert("<s>");
00248 FindOrInsert("</s>");
00249 }
00250
00251 WordIndex Index(const StringPiece &str) const {
00252 Lookup::ConstIterator i;
00253 return lookup_.Find(detail::HashForVocab(str), i) ? i->value : 0;
00254 }
00255
00256 WordIndex FindOrInsert(const StringPiece &word) {
00257 ProbingVocabularyEntry entry = ProbingVocabularyEntry::Make(util::MurmurHashNative(word.data(), word.size()), Size());
00258 Lookup::MutableIterator it;
00259 if (!lookup_.FindOrInsert(entry, it)) {
00260 new_word_(word);
00261 UTIL_THROW_IF(Size() >= std::numeric_limits<lm::WordIndex>::max(), VocabLoadException, "Too many vocabulary words. Change WordIndex to uint64_t in lm/word_index.hh");
00262 }
00263 return it->value;
00264 }
00265
00266 WordIndex Size() const { return lookup_.Size(); }
00267
00268 private:
00269 typedef util::AutoProbing<ProbingVocabularyEntry, util::IdentityHash> Lookup;
00270
00271 Lookup lookup_;
00272
00273 NewWordAction new_word_;
00274 };
00275
00276 }
00277 }
00278
00279 #endif // LM_VOCAB_H