00001 #include "lm/builder/interpolate.hh"
00002
00003 #include "lm/builder/hash_gamma.hh"
00004 #include "lm/builder/payload.hh"
00005 #include "lm/common/compare.hh"
00006 #include "lm/common/joint_order.hh"
00007 #include "lm/common/ngram_stream.hh"
00008 #include "lm/lm_exception.hh"
00009 #include "util/fixed_array.hh"
00010 #include "util/murmur_hash.hh"
00011
00012 #include <iostream>
00013 #include <cassert>
00014 #include <cmath>
00015
00016 namespace lm { namespace builder {
00017 namespace {
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033 class OutputQ {
00034 public:
00035 explicit OutputQ(std::size_t order) : q_delta_(order) {}
00036
00037 void Gram(unsigned order_minus_1, float full_backoff, ProbBackoff &out) {
00038 float &q_del = q_delta_[order_minus_1];
00039 if (order_minus_1) {
00040
00041 q_del = q_delta_[order_minus_1 - 1] / out.backoff * full_backoff;
00042 } else {
00043 q_del = full_backoff;
00044 }
00045 out.prob = log10f(out.prob * q_del);
00046
00047 out.backoff = 0.0;
00048 }
00049
00050 private:
00051
00052
00053 std::vector<float> q_delta_;
00054 };
00055
00056
00057 class OutputProbBackoff {
00058 public:
00059 explicit OutputProbBackoff(std::size_t ) {}
00060
00061 void Gram(unsigned , float full_backoff, ProbBackoff &out) const {
00062
00063 out.prob = std::min(0.0f, log10f(out.prob));
00064 out.backoff = log10f(full_backoff);
00065 }
00066 };
00067
// Joint-order callback performing the actual interpolation.  Enter() is
// called for every n-gram; it combines the n-gram's uninterpolated
// probability with the already-computed interpolated probability of its
// lower-order suffix, pulls this n-gram's backoff from a side stream, and
// hands both to the Output policy (OutputQ or OutputProbBackoff) for final
// formatting in place.
template <class Output> class Callback {
  public:
    // uniform_prob: the 0-gram (uniform) value used as the interpolation floor.
    // backoffs: one side stream per order below the maximum, carrying either
    //   raw float backoffs or (hash, gamma) records when pruning is enabled.
    // prune_thresholds: per-order thresholds; a nonzero entry means that
    //   order was pruned.
    // prune_vocab: whether vocabulary pruning is enabled.
    // specials: ids of the special vocabulary words.
    Callback(float uniform_prob, const util::stream::ChainPositions &backoffs, const std::vector<uint64_t> &prune_thresholds, bool prune_vocab, const SpecialVocab &specials)
      : backoffs_(backoffs.size()), probs_(backoffs.size() + 2),
        prune_thresholds_(prune_thresholds),
        prune_vocab_(prune_vocab),
        output_(backoffs.size() + 1 /* order */),
        specials_(specials) {
      // probs_[0] is the base of the interpolation recursion.
      probs_[0] = uniform_prob;
      for (std::size_t i = 0; i < backoffs.size(); ++i) {
        backoffs_.push_back(backoffs[i]);
      }
    }

    ~Callback() {
      for (std::size_t i = 0; i < backoffs_.size(); ++i) {
        // With pruning, entries for pruned n-grams are legitimately never
        // consumed by Enter(); drain whatever is left on the stream.
        if(prune_vocab_ || prune_thresholds_[i + 1] > 0)
          while(backoffs_[i])
            ++backoffs_[i];

        // Without pruning, every backoff must have been consumed exactly
        // once; leftovers mean the streams desynchronized — unrecoverable.
        if (backoffs_[i]) {
          std::cerr << "Backoffs do not match for order " << (i + 1) << std::endl;
          abort();
        }
      }
    }

    // Called once per n-gram; data points at order_minus_1 + 1 word ids
    // followed by the payload.
    void Enter(unsigned order_minus_1, void *data) {
      NGram<BuildingPayload> gram(data, order_minus_1 + 1);
      BuildingPayload &pay = gram.Value();
      // Interpolate: complete = uninterp + gamma * (lower-order interpolated
      // prob).  probs_[order_minus_1] was written when the suffix was visited
      // (this relies on the joint suffix-order traversal set up in Run()).
      pay.complete.prob = pay.uninterp.prob + pay.uninterp.gamma * probs_[order_minus_1];
      probs_[order_minus_1 + 1] = pay.complete.prob;

      float out_backoff;
      // Only n-grams that can act as a context carry a backoff: not the
      // highest order, and not those whose last word is UNK or EOS.
      if (order_minus_1 < backoffs_.size() && *(gram.end() - 1) != specials_.UNK() && *(gram.end() - 1) != specials_.EOS() && backoffs_[order_minus_1]) {
        if(prune_vocab_ || prune_thresholds_[order_minus_1 + 1] > 0) {
          // Pruned orders key the backoff stream by hash, since records for
          // pruned n-grams may be absent; hash this n-gram's words and scan
          // forward until a match or the stream ends.
          uint64_t current_hash = util::MurmurHashNative(gram.begin(), gram.Order() * sizeof(WordIndex));

          const HashGamma *hashed_backoff = static_cast<const HashGamma*>(backoffs_[order_minus_1].Get());
          while(current_hash != hashed_backoff->hash_value && ++backoffs_[order_minus_1])
            hashed_backoff = static_cast<const HashGamma*>(backoffs_[order_minus_1].Get());

          if(current_hash == hashed_backoff->hash_value) {
            out_backoff = hashed_backoff->gamma;
            ++backoffs_[order_minus_1];
          } else {
            // No record for this n-gram: its backoff was pruned away, so use
            // the multiplicative identity (log10(1.0) == 0).
            out_backoff = 1.0;
          }
        } else {
          // Unpruned orders stream raw floats in the same order as the
          // n-grams: consume the next value directly.
          out_backoff = *static_cast<const float*>(backoffs_[order_minus_1].Get());
          ++backoffs_[order_minus_1];
        }
      } else {
        // No backoff applies; the identity leaves the probability unscaled.
        out_backoff = 1.0;
      }

      // Final formatting (q collapse or prob/backoff split) happens in place.
      output_.Gram(order_minus_1, out_backoff, pay.complete);
    }

    void Exit(unsigned, void *) const {}

  private:
    // Backoff side streams, one per order below the maximum.
    util::FixedArray<util::stream::Stream> backoffs_;

    // probs_[n] holds the most recent interpolated probability at order n;
    // probs_[0] is the uniform distribution.
    std::vector<float> probs_;
    const std::vector<uint64_t>& prune_thresholds_;
    bool prune_vocab_;

    Output output_;
    const SpecialVocab specials_;
};
00142 }
00143
// vocab_size: number of word types; each receives 1/vocab_size under the
// uniform 0-gram distribution that floors the interpolation.
// NOTE(review): whether vocab_size includes <unk> is decided by the caller —
// confirm there.
// output_q selects collapsed q output (true) versus the conventional
// prob/backoff pair (false); see Run().
Interpolate::Interpolate(uint64_t vocab_size, const util::stream::ChainPositions &backoffs, const std::vector<uint64_t>& prune_thresholds, bool prune_vocab, bool output_q, const SpecialVocab &specials)
  : uniform_prob_(1.0 / static_cast<float>(vocab_size)),
    backoffs_(backoffs),
    prune_thresholds_(prune_thresholds),
    prune_vocab_(prune_vocab),
    output_q_(output_q),
    specials_(specials) {}
00151
00152
00153 void Interpolate::Run(const util::stream::ChainPositions &positions) {
00154 assert(positions.size() == backoffs_.size() + 1);
00155 if (output_q_) {
00156 typedef Callback<OutputQ> C;
00157 C callback(uniform_prob_, backoffs_, prune_thresholds_, prune_vocab_, specials_);
00158 JointOrder<C, SuffixOrder>(positions, callback);
00159 } else {
00160 typedef Callback<OutputProbBackoff> C;
00161 C callback(uniform_prob_, backoffs_, prune_thresholds_, prune_vocab_, specials_);
00162 JointOrder<C, SuffixOrder>(positions, callback);
00163 }
00164 }
00165
00166 }}