00001 #ifndef LM_VALUE_BUILD_H
00002 #define LM_VALUE_BUILD_H
00003
00004 #include "lm/weights.hh"
00005 #include "lm/word_index.hh"
00006 #include "util/bit_packing.hh"
00007
00008 #include <vector>
00009
00010 namespace lm {
00011 namespace ngram {
00012
00013 struct Config;
00014 struct BackoffValue;
00015 struct RestValue;
00016
00017 class NoRestBuild {
00018 public:
00019 typedef BackoffValue Value;
00020
00021 NoRestBuild() {}
00022
00023 void SetRest(const WordIndex *, unsigned int, const Prob &) const {}
00024 void SetRest(const WordIndex *, unsigned int, const ProbBackoff &) const {}
00025
00026 template <class Second> bool MarkExtends(ProbBackoff &weights, const Second &) const {
00027 util::UnsetSign(weights.prob);
00028 return false;
00029 }
00030
00031
00032 const static bool kMarkEvenLower = false;
00033 };
00034
00035 class MaxRestBuild {
00036 public:
00037 typedef RestValue Value;
00038
00039 MaxRestBuild() {}
00040
00041 void SetRest(const WordIndex *, unsigned int, const Prob &) const {}
00042 void SetRest(const WordIndex *, unsigned int, RestWeights &weights) const {
00043 weights.rest = weights.prob;
00044 util::SetSign(weights.rest);
00045 }
00046
00047 bool MarkExtends(RestWeights &weights, const RestWeights &to) const {
00048 util::UnsetSign(weights.prob);
00049 if (weights.rest >= to.rest) return false;
00050 weights.rest = to.rest;
00051 return true;
00052 }
00053 bool MarkExtends(RestWeights &weights, const Prob &to) const {
00054 util::UnsetSign(weights.prob);
00055 if (weights.rest >= to.prob) return false;
00056 weights.rest = to.prob;
00057 return true;
00058 }
00059
00060
00061 const static bool kMarkEvenLower = true;
00062 };
00063
00064 template <class Model> class LowerRestBuild {
00065 public:
00066 typedef RestValue Value;
00067
00068 LowerRestBuild(const Config &config, unsigned int order, const typename Model::Vocabulary &vocab);
00069
00070 ~LowerRestBuild();
00071
00072 void SetRest(const WordIndex *, unsigned int, const Prob &) const {}
00073 void SetRest(const WordIndex *vocab_ids, unsigned int n, RestWeights &weights) const {
00074 typename Model::State ignored;
00075 if (n == 1) {
00076 weights.rest = unigrams_[*vocab_ids];
00077 } else {
00078 weights.rest = models_[n-2]->FullScoreForgotState(vocab_ids + 1, vocab_ids + n, *vocab_ids, ignored).prob;
00079 }
00080 }
00081
00082 template <class Second> bool MarkExtends(RestWeights &weights, const Second &) const {
00083 util::UnsetSign(weights.prob);
00084 return false;
00085 }
00086
00087 const static bool kMarkEvenLower = false;
00088
00089 std::vector<float> unigrams_;
00090
00091 std::vector<const Model*> models_;
00092 };
00093
00094 }
00095 }
00096
00097 #endif // LM_VALUE_BUILD_H