00001 #ifndef LM_SEARCH_HASHED_H
00002 #define LM_SEARCH_HASHED_H
00003
#include "lm/model_type.hh"
#include "lm/config.hh"
#include "lm/read_arpa.hh"
#include "lm/return.hh"
#include "lm/weights.hh"

#include "util/bit_packing.hh"
#include "util/probing_hash_table.hh"

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iostream>
#include <vector>
00016
00017 namespace util { class FilePiece; }
00018
00019 namespace lm {
00020 namespace ngram {
00021 class BinaryFormat;
00022 class ProbingVocabulary;
00023 namespace detail {
00024
00025 inline uint64_t CombineWordHash(uint64_t current, const WordIndex next) {
00026 uint64_t ret = (current * 8978948897894561157ULL) ^ (static_cast<uint64_t>(1 + next) * 17894857484156487943ULL);
00027 return ret;
00028 }
00029
00030 #pragma pack(push)
00031 #pragma pack(4)
00032 struct ProbEntry {
00033 uint64_t key;
00034 Prob value;
00035 typedef uint64_t Key;
00036 typedef Prob Value;
00037 uint64_t GetKey() const {
00038 return key;
00039 }
00040 };
00041
00042 #pragma pack(pop)
00043
// Non-owning handle to the probability of a highest-order n-gram entry.
// A default-constructed handle means "not found".  Prob() dereferences the
// stored pointer, so call it only when Found() is true.
class LongestPointer {
  public:
    // "Not found" handle.
    LongestPointer() : to_(NULL) {}

    // Refers to `to` by address; `to` must outlive this handle.
    explicit LongestPointer(const float &to) : to_(&to) {}

    bool Found() const { return NULL != to_; }

    // Precondition: Found().
    float Prob() const { return to_[0]; }

  private:
    const float *to_;
};
00061
00062 template <class Value> class HashedSearch {
00063 public:
00064 typedef uint64_t Node;
00065
00066 typedef typename Value::ProbingProxy UnigramPointer;
00067 typedef typename Value::ProbingProxy MiddlePointer;
00068 typedef ::lm::ngram::detail::LongestPointer LongestPointer;
00069
00070 static const ModelType kModelType = Value::kProbingModelType;
00071 static const bool kDifferentRest = Value::kDifferentRest;
00072 static const unsigned int kVersion = 0;
00073
00074
00075 static void UpdateConfigFromBinary(const BinaryFormat &, const std::vector<uint64_t> &, uint64_t, Config &) {}
00076
00077 static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) {
00078 uint64_t ret = Unigram::Size(counts[0]);
00079 for (unsigned char n = 1; n < counts.size() - 1; ++n) {
00080 ret += Middle::Size(counts[n], config.probing_multiplier);
00081 }
00082 return ret + Longest::Size(counts.back(), config.probing_multiplier);
00083 }
00084
00085 uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config);
00086
00087 void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, ProbingVocabulary &vocab, BinaryFormat &backing);
00088
00089 unsigned char Order() const {
00090 return middle_.size() + 2;
00091 }
00092
00093 typename Value::Weights &UnknownUnigram() { return unigram_.Unknown(); }
00094
00095 UnigramPointer LookupUnigram(WordIndex word, Node &next, bool &independent_left, uint64_t &extend_left) const {
00096 extend_left = static_cast<uint64_t>(word);
00097 next = extend_left;
00098 UnigramPointer ret(unigram_.Lookup(word));
00099 independent_left = ret.IndependentLeft();
00100 return ret;
00101 }
00102
00103 MiddlePointer Unpack(uint64_t extend_pointer, unsigned char extend_length, Node &node) const {
00104 node = extend_pointer;
00105 return MiddlePointer(middle_[extend_length - 2].MustFind(extend_pointer)->value);
00106 }
00107
00108 MiddlePointer LookupMiddle(unsigned char order_minus_2, WordIndex word, Node &node, bool &independent_left, uint64_t &extend_pointer) const {
00109 node = CombineWordHash(node, word);
00110 typename Middle::ConstIterator found;
00111 if (!middle_[order_minus_2].Find(node, found)) {
00112 independent_left = true;
00113 return MiddlePointer();
00114 }
00115 extend_pointer = node;
00116 MiddlePointer ret(found->value);
00117 independent_left = ret.IndependentLeft();
00118 return ret;
00119 }
00120
00121 LongestPointer LookupLongest(WordIndex word, const Node &node) const {
00122
00123 typename Longest::ConstIterator found;
00124 if (!longest_.Find(CombineWordHash(node, word), found)) return LongestPointer();
00125 return LongestPointer(found->value.prob);
00126 }
00127
00128
00129
00130 bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const {
00131 assert(begin != end);
00132 node = static_cast<Node>(*begin);
00133 for (const WordIndex *i = begin + 1; i < end; ++i) {
00134 node = CombineWordHash(node, *i);
00135 }
00136 return true;
00137 }
00138
00139 private:
00140
00141 void DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn);
00142
00143 template <class Build> void ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build);
00144
00145 class Unigram {
00146 public:
00147 Unigram() {}
00148
00149 Unigram(void *start, uint64_t count) :
00150 unigram_(static_cast<typename Value::Weights*>(start))
00151 #ifdef DEBUG
00152 , count_(count)
00153 #endif
00154 {}
00155
00156 static uint64_t Size(uint64_t count) {
00157 return (count + 1) * sizeof(typename Value::Weights);
00158 }
00159
00160 const typename Value::Weights &Lookup(WordIndex index) const {
00161 #ifdef DEBUG
00162 assert(index < count_);
00163 #endif
00164 return unigram_[index];
00165 }
00166
00167 typename Value::Weights &Unknown() { return unigram_[0]; }
00168
00169
00170 typename Value::Weights *Raw() { return unigram_; }
00171
00172 private:
00173 typename Value::Weights *unigram_;
00174 #ifdef DEBUG
00175 uint64_t count_;
00176 #endif
00177 };
00178
00179 Unigram unigram_;
00180
00181 typedef util::ProbingHashTable<typename Value::ProbingEntry, util::IdentityHash> Middle;
00182 std::vector<Middle> middle_;
00183
00184 typedef util::ProbingHashTable<ProbEntry, util::IdentityHash> Longest;
00185 Longest longest_;
00186 };
00187
00188 }
00189 }
00190 }
00191
00192 #endif // LM_SEARCH_HASHED_H