00001 #ifndef LM_TRIE_H
00002 #define LM_TRIE_H
00003
00004 #include "lm/weights.hh"
00005 #include "lm/word_index.hh"
00006 #include "util/bit_packing.hh"
00007
00008 #include <cstddef>
00009
00010 #include <stdint.h>
00011
00012 namespace lm {
00013 namespace ngram {
00014 struct Config;
00015 namespace trie {
00016
// A pair of 64-bit indices delimiting a range of entries.
// Unigram::Find and the bit-packed Find/ReadEntry methods fill or consume it.
struct NodeRange {
  uint64_t begin;  // first index of the range
  uint64_t end;    // index paired with begin as the range's upper bound
};
00020
00021
00022 struct UnigramValue {
00023 ProbBackoff weights;
00024 uint64_t next;
00025 uint64_t Next() const { return next; }
00026 };
00027
00028 class UnigramPointer {
00029 public:
00030 explicit UnigramPointer(const ProbBackoff &to) : to_(&to) {}
00031
00032 UnigramPointer() : to_(NULL) {}
00033
00034 bool Found() const { return to_ != NULL; }
00035
00036 float Prob() const { return to_->prob; }
00037 float Backoff() const { return to_->backoff; }
00038 float Rest() const { return Prob(); }
00039
00040 private:
00041 const ProbBackoff *to_;
00042 };
00043
00044 class Unigram {
00045 public:
00046 Unigram() {}
00047
00048 void Init(void *start) {
00049 unigram_ = static_cast<UnigramValue*>(start);
00050 }
00051
00052 static uint64_t Size(uint64_t count) {
00053
00054 return (count + 2) * sizeof(UnigramValue);
00055 }
00056
00057 const ProbBackoff &Lookup(WordIndex index) const { return unigram_[index].weights; }
00058
00059 ProbBackoff &Unknown() { return unigram_[0].weights; }
00060
00061 UnigramValue *Raw() {
00062 return unigram_;
00063 }
00064
00065 UnigramPointer Find(WordIndex word, NodeRange &next) const {
00066 UnigramValue *val = unigram_ + word;
00067 next.begin = val->next;
00068 next.end = (val+1)->next;
00069 return UnigramPointer(val->weights);
00070 }
00071
00072 private:
00073 UnigramValue *unigram_;
00074 };
00075
// Common state for the bit-packed n-gram arrays (BitPackedMiddle and
// BitPackedLongest derive from this class).
class BitPacked {
  public:
    // Zero-initializes every member so that a default-constructed object has
    // a well-defined state (InsertIndex() == 0, base_ == NULL) until
    // BaseInit() is called, rather than holding indeterminate values.
    BitPacked()
      : word_bits_(0),
        total_bits_(0),
        word_mask_(0),
        base_(NULL),
        insert_index_(0),
        max_vocab_(0) {}

    // Current value of the insertion counter.
    uint64_t InsertIndex() const {
      return insert_index_;
    }

  protected:
    // Storage size computation shared by derived classes; defined elsewhere.
    // NOTE(review): presumably returns bytes for `entries` records holding a
    // word index (sized from max_vocab) plus `remaining_bits` -- confirm in
    // the implementation file.
    static uint64_t BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits);

    // Shared initialization used by derived classes; defined elsewhere.
    void BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits);

    // Bit offset added after the start of an entry to skip its word field
    // (see BitPackedMiddle::ReadEntry).
    uint8_t word_bits_;
    // Bit stride between consecutive entries (see BitPackedMiddle::ReadEntry).
    uint8_t total_bits_;
    // NOTE(review): presumably the mask selecting word_bits_ bits -- confirm.
    uint64_t word_mask_;

    // Start of the packed storage; NULL until initialized. Not owned here.
    uint8_t *base_;

    // insert_index_ is what InsertIndex() reports; max_vocab_ is presumably
    // the value passed to BaseInit -- confirm in the implementation file.
    uint64_t insert_index_, max_vocab_;
};
00097
00098 template <class Bhiksha> class BitPackedMiddle : public BitPacked {
00099 public:
00100 static uint64_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const Config &config);
00101
00102
00103 BitPackedMiddle(void *base, uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config);
00104
00105 util::BitAddress Insert(WordIndex word);
00106
00107 void FinishedLoading(uint64_t next_end, const Config &config);
00108
00109 util::BitAddress Find(WordIndex word, NodeRange &range, uint64_t &pointer) const;
00110
00111 util::BitAddress ReadEntry(uint64_t pointer, NodeRange &range) {
00112 uint64_t addr = pointer * total_bits_;
00113 addr += word_bits_;
00114 bhiksha_.ReadNext(base_, addr + quant_bits_, pointer, total_bits_, range);
00115 return util::BitAddress(base_, addr);
00116 }
00117
00118 private:
00119 uint8_t quant_bits_;
00120 Bhiksha bhiksha_;
00121
00122 const BitPacked *next_source_;
00123 };
00124
00125 class BitPackedLongest : public BitPacked {
00126 public:
00127 static uint64_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) {
00128 return BaseSize(entries, max_vocab, quant_bits);
00129 }
00130
00131 BitPackedLongest() {}
00132
00133 void Init(void *base, uint8_t quant_bits, uint64_t max_vocab) {
00134 BaseInit(base, max_vocab, quant_bits);
00135 }
00136
00137 util::BitAddress Insert(WordIndex word);
00138
00139 util::BitAddress Find(WordIndex word, const NodeRange &node) const;
00140 };
00141
00142 }
00143 }
00144 }
00145
00146 #endif // LM_TRIE_H