00001 #include "lm/search_hashed.hh"
00002
00003 #include "lm/binary_format.hh"
00004 #include "lm/blank.hh"
00005 #include "lm/lm_exception.hh"
00006 #include "lm/model.hh"
00007 #include "lm/read_arpa.hh"
00008 #include "lm/value.hh"
00009 #include "lm/vocab.hh"
00010
00011 #include "util/bit_packing.hh"
00012 #include "util/file_piece.hh"
00013
00014 #include <string>
00015
00016 namespace lm {
00017 namespace ngram {
00018
00019 class ProbingModel;
00020
00021 namespace {
00022
00023
00024 template <class Middle> class ActivateLowerMiddle {
00025 public:
00026 explicit ActivateLowerMiddle(Middle &middle) : modify_(middle) {}
00027
00028 void operator()(const WordIndex *vocab_ids, const unsigned int n) {
00029 uint64_t hash = static_cast<WordIndex>(vocab_ids[1]);
00030 for (const WordIndex *i = vocab_ids + 2; i < vocab_ids + n; ++i) {
00031 hash = detail::CombineWordHash(hash, *i);
00032 }
00033 typename Middle::MutableIterator i;
00034
00035 if (!modify_.UnsafeMutableFind(hash, i))
00036 UTIL_THROW(FormatLoadException, "The context of every " << n << "-gram should appear as a " << (n-1) << "-gram");
00037 SetExtension(i->value.backoff);
00038 }
00039
00040 private:
00041 Middle &modify_;
00042 };
00043
00044 template <class Weights> class ActivateUnigram {
00045 public:
00046 explicit ActivateUnigram(Weights *unigram) : modify_(unigram) {}
00047
00048 void operator()(const WordIndex *vocab_ids, const unsigned int ) {
00049
00050 SetExtension(modify_[vocab_ids[1]].backoff);
00051 }
00052
00053 private:
00054 Weights *modify_;
00055 };
00056
00057
00058 template <class Value> void FindLower(
00059 const std::vector<uint64_t> &keys,
00060 typename Value::Weights &unigram,
00061 std::vector<util::ProbingHashTable<typename Value::ProbingEntry, util::IdentityHash> > &middle,
00062 std::vector<typename Value::Weights *> &between) {
00063 typename util::ProbingHashTable<typename Value::ProbingEntry, util::IdentityHash>::MutableIterator iter;
00064 typename Value::ProbingEntry entry;
00065
00066 entry.value.backoff = kNoExtensionBackoff;
00067
00068 for (int lower = keys.size() - 2; ; --lower) {
00069 if (lower == -1) {
00070 between.push_back(&unigram);
00071 return;
00072 }
00073 entry.key = keys[lower];
00074 bool found = middle[lower].FindOrInsert(entry, iter);
00075 between.push_back(&iter->value);
00076 if (found) return;
00077 }
00078 }
00079
00080
// Assigns values to the lower-order entries that FindLower just created
// (every element of |between| except the last; the last already existed or
// is the unigram).  Each created entry's probability is synthesized from the
// existing entry's probability plus the backoffs encountered along the way,
// mimicking how a backed-off query would score it.  Finally, extension marks
// are propagated down the chain of entries.
//
// added      weights of the n-gram whose insertion triggered this call.
// build      value policy object (provides MarkExtends / SetRest).
// between    output of FindLower: entry pointers, longest order first.
// n          order of the n-gram just inserted.
// vocab_ids  word ids as filled by ReadNGram through a reverse iterator.
// unigrams   unigram weights array indexed by word id.
// middle     hash tables for the middle orders (middle[0] holds bigrams).
template <class Added, class Build> void AdjustLower(
    const Added &added,
    const Build &build,
    std::vector<typename Build::Value::Weights *> &between,
    const unsigned int n,
    const std::vector<WordIndex> &vocab_ids,
    typename Build::Value::Weights *unigrams,
    std::vector<util::ProbingHashTable<typename Build::Value::ProbingEntry, util::IdentityHash> > &middle) {
  typedef typename Build::Value Value;
  if (between.size() == 1) {
    // Nothing was created; just record that the existing context extends.
    build.MarkExtends(*between.front(), added);
    return;
  }
  typedef util::ProbingHashTable<typename Value::ProbingEntry, util::IdentityHash> Middle;
  // Base the synthesized probabilities on the existing entry's prob; -fabs
  // strips any sign-bit flag (cf. util::SetSign in ReadNGrams) so the result
  // is a plain negative log probability.
  float prob = -fabs(between.back()->prob);
  // Order of the lowest entry in |between| — the one that already existed.
  unsigned char basis = n - between.size();
  assert(basis != 0);
  typename Build::Value::Weights **change = &between.back();
  // Step up to the shortest newly-created entry.
  --change;
  if (basis == 1) {
    // The existing entry is the unigram: add its backoff to synthesize the
    // bigram's probability, then continue from order 2.
    float &backoff = unigrams[vocab_ids[1]].backoff;
    SetExtension(backoff);
    prob += backoff;
    (*change)->prob = prob;
    build.SetRest(&*vocab_ids.begin(), 2, **change);
    basis = 2;
    --change;
  }
  // Running hash of the context words vocab_ids[1..basis]; used below to
  // look up the backoff of each successively longer context.
  uint64_t backoff_hash = static_cast<uint64_t>(vocab_ids[1]);
  for (unsigned char i = 2; i <= basis; ++i) {
    backoff_hash = detail::CombineWordHash(backoff_hash, vocab_ids[i]);
  }
  for (; basis < n - 1; ++basis, --change) {
    typename Middle::MutableIterator gotit;
    if (middle[basis - 2].UnsafeMutableFind(backoff_hash, gotit)) {
      // The context exists: fold its backoff into the synthesized prob and
      // mark it as extendable.
      float &backoff = gotit->value.backoff;
      SetExtension(backoff);
      prob += backoff;
    }
    (*change)->prob = prob;
    build.SetRest(&*vocab_ids.begin(), basis + 1, **change);
    backoff_hash = detail::CombineWordHash(backoff_hash, vocab_ids[basis+1]);
  }

  // Propagate extension marks down the chain: the inserted n-gram extends
  // its longest prefix, which extends the next shorter one, and so on.
  typename std::vector<typename Value::Weights *>::const_iterator i(between.begin());
  build.MarkExtends(**i, added);
  const typename Value::Weights *longer = *i;
  for (++i; i != between.end(); ++i) {
    build.MarkExtends(**i, *longer);
    longer = *i;
  }
}
00137
00138
00139 template <class Build> void MarkLower(
00140 const std::vector<uint64_t> &keys,
00141 const Build &build,
00142 typename Build::Value::Weights &unigram,
00143 std::vector<util::ProbingHashTable<typename Build::Value::ProbingEntry, util::IdentityHash> > &middle,
00144 int start_order,
00145 const typename Build::Value::Weights &longer) {
00146 if (start_order == 0) return;
00147
00148 for (int even_lower = start_order - 2 ; ; --even_lower) {
00149 if (even_lower == -1) {
00150 build.MarkExtends(unigram, longer);
00151 return;
00152 }
00153 if (!build.MarkExtends(
00154 middle[even_lower].UnsafeMutableMustFind(keys[even_lower])->value,
00155 longer)) return;
00156 }
00157 }
00158
// Reads the order-n section of an ARPA file, inserting each of the |count|
// entries into |store| and repairing the lower-order tables as it goes
// (FindLower / AdjustLower / MarkLower), then activating the context.
//
// f          ARPA stream positioned at (or before) the n-gram header.
// n          order being read (must be >= 2).
// count      number of n-grams declared for this order.
// vocab      maps surface strings to WordIndex during ReadNGram.
// build      value policy (provides SetRest / MarkExtends / kMarkEvenLower).
// unigrams   unigram weights array.
// middle     hash tables for the middle orders (middle[0] holds bigrams).
// activate   functor marking the (n-1)-word context as extendable.
// store      destination table for order n (a Middle or the Longest table).
// warn       accumulates warnings about positive log probabilities.
template <class Build, class Activate, class Store> void ReadNGrams(
    util::FilePiece &f,
    const unsigned int n,
    const size_t count,
    const ProbingVocabulary &vocab,
    const Build &build,
    typename Build::Value::Weights *unigrams,
    std::vector<util::ProbingHashTable<typename Build::Value::ProbingEntry, util::IdentityHash> > &middle,
    Activate activate,
    Store &store,
    PositiveProbWarn &warn) {
  typedef typename Build::Value Value;
  assert(n >= 2);
  ReadNGramHeader(f, n);

  // vocab_ids is filled through a reverse iterator below, so it holds the
  // words in reverse of their order on the ARPA line.
  std::vector<WordIndex> vocab_ids(n);
  // keys[h] is the combined hash of vocab_ids[0..h+1]; keys[n-2] keys |store|.
  std::vector<uint64_t> keys(n-1);
  typename Store::Entry entry;
  std::vector<typename Value::Weights *> between;
  for (size_t i = 0; i < count; ++i) {
    ReadNGram(f, n, vocab, vocab_ids.rbegin(), entry.value, warn);
    build.SetRest(&*vocab_ids.begin(), n, entry.value);

    keys[0] = detail::CombineWordHash(static_cast<uint64_t>(vocab_ids.front()), vocab_ids[1]);
    for (unsigned int h = 1; h < n - 1; ++h) {
      keys[h] = detail::CombineWordHash(keys[h-1], vocab_ids[h+1]);
    }

    // Flag the probability's sign bit — presumably distinguishing entries
    // read from the ARPA file from blanks synthesized by AdjustLower (which
    // strips the flag with -fabs); confirm against util/bit_packing.
    util::SetSign(entry.value.prob);
    entry.key = keys[n-2];

    store.Insert(entry);
    between.clear();
    // Make sure every proper prefix exists, then give newly-created blanks
    // synthesized values and propagate extension marks.
    FindLower<Value>(keys, unigrams[vocab_ids.front()], middle, between);
    AdjustLower<typename Store::Entry::Value, Build>(entry.value, build, between, n, vocab_ids, unigrams, middle);
    if (Build::kMarkEvenLower) MarkLower<Build>(keys, build, unigrams[vocab_ids.front()], middle, n - between.size() - 1, *between.back());
    activate(&*vocab_ids.begin(), n);
  }

  store.FinishedInserting();
}
00202
00203 }
00204 namespace detail {
00205
00206 template <class Value> uint8_t *HashedSearch<Value>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) {
00207 unigram_ = Unigram(start, counts[0]);
00208 start += Unigram::Size(counts[0]);
00209 std::size_t allocated;
00210 middle_.clear();
00211 for (unsigned int n = 2; n < counts.size(); ++n) {
00212 allocated = Middle::Size(counts[n - 1], config.probing_multiplier);
00213 middle_.push_back(Middle(start, allocated));
00214 start += allocated;
00215 }
00216 allocated = Longest::Size(counts.back(), config.probing_multiplier);
00217 longest_ = Longest(start, allocated);
00218 start += allocated;
00219 return start;
00220 }
00221
00222
00223
00224
00225
00226
00227
00228
00229
00230
00231
// Builds the in-memory probing structures directly from an ARPA file.  The
// unnamed const char* parameter is accepted for interface compatibility but
// unused here.  Grows the backing memory, relocates the vocabulary, reads
// the unigrams, then dispatches on the Value policy for higher orders.
template <class Value> void HashedSearch<Value>::InitializeFromARPA(const char * /* file */, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, ProbingVocabulary &vocab, BinaryFormat &backing) {
  void *vocab_rebase;
  void *search_base = backing.GrowForSearch(Size(counts, config), vocab.UnkCountChangePadding(), vocab_rebase);
  // Growing the backing store can move the vocabulary; rebase it.
  vocab.Relocate(vocab_rebase);
  SetupMemory(reinterpret_cast<uint8_t*>(search_base), counts, config);

  PositiveProbWarn warn(config.positive_log_probability);
  Read1Grams(f, counts[0], vocab, unigram_.Raw(), warn);
  CheckSpecials(config, vocab);
  // Value-specific (backoff-only vs rest-cost) reading of orders 2..N.
  DispatchBuild(f, counts, config, vocab, warn);
}
00243
// Backoff-only models use the trivial NoRestBuild policy; config and vocab
// are unused in this specialization (kept for a uniform signature).
template <> void HashedSearch<BackoffValue>::DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn) {
  NoRestBuild build;
  ApplyBuild(f, counts, vocab, warn, build);
}
00248
00249 template <> void HashedSearch<RestValue>::DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn) {
00250 switch (config.rest_function) {
00251 case Config::REST_MAX:
00252 {
00253 MaxRestBuild build;
00254 ApplyBuild(f, counts, vocab, warn, build);
00255 }
00256 break;
00257 case Config::REST_LOWER:
00258 {
00259 LowerRestBuild<ProbingModel> build(config, counts.size(), vocab);
00260 ApplyBuild(f, counts, vocab, warn, build);
00261 }
00262 break;
00263 }
00264 }
00265
00266 template <class Value> template <class Build> void HashedSearch<Value>::ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build) {
00267 for (WordIndex i = 0; i < counts[0]; ++i) {
00268 build.SetRest(&i, (unsigned int)1, unigram_.Raw()[i]);
00269 }
00270
00271 try {
00272 if (counts.size() > 2) {
00273 ReadNGrams<Build, ActivateUnigram<typename Value::Weights>, Middle>(
00274 f, 2, counts[1], vocab, build, unigram_.Raw(), middle_, ActivateUnigram<typename Value::Weights>(unigram_.Raw()), middle_[0], warn);
00275 }
00276 for (unsigned int n = 3; n < counts.size(); ++n) {
00277 ReadNGrams<Build, ActivateLowerMiddle<Middle>, Middle>(
00278 f, n, counts[n-1], vocab, build, unigram_.Raw(), middle_, ActivateLowerMiddle<Middle>(middle_[n-3]), middle_[n-2], warn);
00279 }
00280 if (counts.size() > 2) {
00281 ReadNGrams<Build, ActivateLowerMiddle<Middle>, Longest>(
00282 f, counts.size(), counts[counts.size() - 1], vocab, build, unigram_.Raw(), middle_, ActivateLowerMiddle<Middle>(middle_.back()), longest_, warn);
00283 } else {
00284 ReadNGrams<Build, ActivateUnigram<typename Value::Weights>, Longest>(
00285 f, counts.size(), counts[counts.size() - 1], vocab, build, unigram_.Raw(), middle_, ActivateUnigram<typename Value::Weights>(unigram_.Raw()), longest_, warn);
00286 }
00287 } catch (util::ProbingSizeException &e) {
00288 UTIL_THROW(util::ProbingSizeException, "Avoid pruning n-grams like \"bar baz quux\" when \"foo bar baz quux\" is still in the model. KenLM will work when this pruning happens, but the probing model assumes these events are rare enough that using blank space in the probing hash table will cover all of them. Increase probing_multiplier (-p to build_binary) to add more blank spaces.\n");
00289 }
00290 ReadEnd(f);
00291 }
00292
// Explicit instantiations for the two value policies used by this search.
template class HashedSearch<BackoffValue>;
template class HashedSearch<RestValue>;
00295
00296 }
00297 }
00298 }