00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038 #ifndef LM_LEFT_H
00039 #define LM_LEFT_H
00040
00041 #include "lm/max_order.hh"
00042 #include "lm/state.hh"
00043 #include "lm/return.hh"
00044
00045 #include "util/murmur_hash.hh"
00046
00047 #include <algorithm>
00048
00049 namespace lm {
00050 namespace ngram {
00051
// RuleScore accumulates a score for a sequence of terminals and non-terminals
// and writes the resulting chart state into a caller-supplied ChartState.
//
// Template parameter M is the model type. From the calls made below it must
// provide BeginSentenceState(), FullScore(), ExtendLeft(), UnRest() and
// Order().
//
// Usage as implied by the interface: construct (or Reset), optionally call
// BeginSentence() or BeginNonTerminal() first, then Terminal()/NonTerminal()
// for each symbol in order, and finally Finish() to obtain the total.
//
// NOTE(review): the leading 5-digit numbers on each line are listing residue
// that is part of the file text; they are preserved untouched here.
00052 template <class M> class RuleScore {
00053 public:
// Binds the scorer to `model` (held by reference) and to `out` (held by
// pointer). Both must outlive this object. The output state starts with
// empty left and right parts, and the accumulated score starts at 0.
00054 explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(&out), left_done_(false), prob_(0.0) {
00055 out.left.length = 0;
00056 out.right.length = 0;
00057 }
00058
// Seeds the right state with the model's begin-of-sentence state.
00059 void BeginSentence() {
00060 out_->right = model_.BeginSentenceState();
00061
// out_->left is not touched here; marking the left side as done means no
// later symbol will append to it.
00062 left_done_ = true;
00063 }
00064
// Scores a single word given the current right state and advances the right
// state in place.
00065 void Terminal(WordIndex word) {
// FullScore reads the old state and writes the new one, so the old state is
// copied first to let out_->right serve as the output argument.
00066 State copy(out_->right);
00067 FullScoreReturn ret(model_.FullScore(copy, word, out_->right));
// Left state already closed: only the score is accumulated.
00068 if (left_done_) { prob_ += ret.prob; return; }
// The model reports independent_left: take ret.prob and close the left
// state without recording a pointer.
00069 if (ret.independent_left) {
00070 prob_ += ret.prob;
00071 left_done_ = true;
00072 return;
00073 }
// Otherwise record the extend_left pointer in the left state and accumulate
// ret.rest instead of ret.prob.
00074 out_->left.pointers[out_->left.length++] = ret.extend_left;
00075 prob_ += ret.rest;
// If the new right state is not exactly one word longer than the old one,
// close the left state. (Presumably this means the match did not use the
// whole previous context -- confirm against the model implementation.)
00076 if (out_->right.length != copy.length + 1)
00077 left_done_ = true;
00078 }
00079
00080
// Shortcut for a rule whose first symbol is a non-terminal: the whole output
// state is copied from `in`.
// Note that prob_ is overwritten with `prob` rather than added to, so any
// score accumulated before this call is discarded.
00081 void BeginNonTerminal(const ChartState &in, float prob = 0.0) {
00082 prob_ = prob;
00083 *out_ = in;
00084 left_done_ = in.left.full;
00085 }
00086
// Appends a non-terminal whose own state is `in`, adding `prob` to the
// accumulated score and merging `in` into the output state.
00087 void NonTerminal(const ChartState &in, float prob = 0.0) {
00088 prob_ += prob;
00089
// Case 1: `in` has no left pointers. If its left state is marked full, the
// backoffs stored in the current right state are charged, the left side is
// closed and the right state is replaced by in.right. If it is not full,
// nothing besides `prob` changes.
00090 if (!in.left.length) {
00091 if (in.left.full) {
00092 for (const float *i = out_->right.backoff; i < out_->right.backoff + out_->right.length; ++i) prob_ += *i;
00093 left_done_ = true;
00094 out_->right = in.right;
00095 }
00096 return;
00097 }
00098
// Case 2: the current right state is empty, so there is no context with
// which to extend in.left.
00099 if (!out_->right.length) {
00100 out_->right = in.right;
00101 if (left_done_) {
// Left side already closed: apply model_.UnRest over all of in's left
// pointers (third argument 1; presumably the n-gram order of the first
// pointer -- confirm against M::UnRest).
00102 prob_ += model_.UnRest(in.left.pointers, in.left.pointers + in.left.length, 1);
00103 return;
00104 }
// Left side still open: if it already holds pointers it is closed as-is,
// otherwise in's left state is adopted wholesale.
00105 if (out_->left.length) {
00106 left_done_ = true;
00107 } else {
00108 out_->left = in.left;
00109 left_done_ = in.left.full;
00110 }
00111 return;
00112 }
00113
// Case 3: general case. Two scratch buffers are used alternately: each
// ExtendLeft call reads backoffs from one and writes to the other.
// KENLM_MAX_ORDER is not defined in this file (presumably it comes from
// lm/max_order.hh).
00114 float backoffs[KENLM_MAX_ORDER - 1], backoffs2[KENLM_MAX_ORDER - 1];
00115 float *back = backoffs, *back2 = backoffs2;
// Number of words of the current right state still in use. It is passed by
// reference through the private ExtendLeft helper below and re-read after
// each call.
00116 unsigned char next_use = out_->right.length;
00117
00118
// First left pointer: reads backoffs directly from the current right state.
00119 if (ExtendLeft(in, next_use, 1, out_->right.backoff, back)) return;
00120
00121
// Remaining left pointers: read what the previous step wrote, then swap
// buffers so that `back` always holds the latest backoffs.
00122 for (unsigned char extend_length = 2; extend_length <= in.left.length; ++extend_length) {
00123 if (ExtendLeft(in, next_use, extend_length, back, back2)) return;
00124 std::swap(back, back2);
00125 }
00126
// in's left state is full: charge the remaining backoffs, close the left
// side and take in.right as the new right state.
00127 if (in.left.full) {
00128 for (const float *i = back; i != back + next_use; ++i) prob_ += *i;
00129 left_done_ = true;
00130 out_->right = in.right;
00131 return;
00132 }
00133
00134
// in.right is shorter than in.left: take in.right as the new right state
// without appending the existing words.
00135 if (in.right.length < in.left.length) {
00136 out_->right = in.right;
00137 return;
00138 }
00139
00140
// Otherwise build the concatenated right state in place. First move the
// next_use existing words up by in.right.length slots, iterating from the
// last word so that overlapping ranges are handled.
// NOTE(review): the loop ends by decrementing `i` to one element before
// out_->right.words; forming that pointer is formally undefined behaviour
// in C++. std::copy_backward would express the same move without it.
// NOTE(review): no check is visible here that in.right.length + next_use
// fits in the words/backoff arrays -- presumably bounded by the model
// order; confirm.
00141 for (WordIndex *i = out_->right.words + next_use - 1; i >= out_->right.words; --i) {
00142 *(i + in.right.length) = *i;
00143 }
00144
// in.right's words go in front of the shifted words.
00145 std::copy(in.right.words, in.right.words + in.right.length, out_->right.words);
00146
// Backoffs follow the same layout: in.right's backoffs first, then the
// backoffs produced by the ExtendLeft calls above.
00147 std::copy(in.right.backoff, in.right.backoff + in.right.length, out_->right.backoff);
00148 std::copy(back, back + next_use, out_->right.backoff + in.right.length);
00149 out_->right.length = in.right.length + next_use;
00150 }
00151
// Finalizes the output state and returns the accumulated score.
00152 float Finish() {
00153
// The left state is marked full if it was closed explicitly or if it
// already holds Order() - 1 pointers.
00154 out_->left.full = left_done_ || (out_->left.length == model_.Order() - 1);
00155 return prob_;
00156 }
00157
// Clears the score and empties the current output state so the scorer can be
// reused with the same target.
00158 void Reset() {
00159 prob_ = 0.0;
00160 left_done_ = false;
00161 out_->left.length = 0;
00162 out_->right.length = 0;
00163 }
// Retargets the scorer at `replacement` and then resets it. `replacement`
// must outlive further use of this object, since only its address is kept.
00164 void Reset(ChartState &replacement) {
00165 out_ = &replacement;
00166 Reset();
00167 }
00168
00169 private:
// Extends the left pointer in.left.pointers[extend_length - 1] using the
// first `next_use` words of the current right state.
//   next_use      in/out count of right-state words in use; handed to
//                 model_.ExtendLeft as its last argument (presumably
//                 updated there by reference -- confirm against M).
//   extend_length 1-based index of the left pointer being extended.
//   back_in       backoffs to read.
//   back_out      buffer handed to the model for the new backoffs.
// Returns true when NonTerminal should stop immediately (no right-state
// words remain in use), false when it should continue.
00170 bool ExtendLeft(const ChartState &in, unsigned char &next_use, unsigned char extend_length, const float *back_in, float *back_out) {
00171 ProcessRet(model_.ExtendLeft(
00172 out_->right.words, out_->right.words + next_use,
00173 back_in,
00174 in.left.pointers[extend_length - 1], extend_length,
00175 back_out,
00176 next_use));
// Fewer words in use than the right state holds: close the left side.
00177 if (next_use != out_->right.length) {
00178 left_done_ = true;
00179 if (!next_use) {
00180
// Early exit: adopt in.right and apply UnRest to the left pointers that
// were not extended, starting from index extend_length.
00181 out_->right = in.right;
00182 prob_ += model_.UnRest(in.left.pointers + extend_length, in.left.pointers + in.left.length, extend_length + 1);
00183 return true;
00184 }
00185 }
00186
// Keep scoring the remaining left pointers.
00187 return false;
00188 }
00189
// Folds one FullScoreReturn into the running score and left state, using
// the same three branches as Terminal() but without its right-state length
// check.
00190 void ProcessRet(const FullScoreReturn &ret) {
00191 if (left_done_) {
00192 prob_ += ret.prob;
00193 return;
00194 }
00195 if (ret.independent_left) {
00196 prob_ += ret.prob;
00197 left_done_ = true;
00198 return;
00199 }
00200 out_->left.pointers[out_->left.length++] = ret.extend_left;
00201 prob_ += ret.rest;
00202 }
00203
// Model queried for scores; not owned.
00204 const M &model_;
00205
// State being built; not owned, may be retargeted by Reset(ChartState &).
00206 ChartState *out_;
00207
// True once no further pointers may be appended to out_->left.
00208 bool left_done_;
00209
// Score accumulated so far; returned by Finish().
00210 float prob_;
00211 };
00212
00213 }
00214 }
00215
00216 #endif // LM_LEFT_H