00001 #include "lm/builder/initial_probabilities.hh"
00002 
00003 #include "lm/builder/discount.hh"
00004 #include "lm/builder/hash_gamma.hh"
00005 #include "lm/builder/payload.hh"
00006 #include "lm/common/special.hh"
00007 #include "lm/common/ngram_stream.hh"
00008 #include "util/murmur_hash.hh"
00009 #include "util/file.hh"
00010 #include "util/stream/chain.hh"
00011 #include "util/stream/io.hh"
00012 #include "util/stream/stream.hh"
00013 
00014 #include <vector>
00015 
00016 namespace lm { namespace builder {
00017 
00018 namespace {
// Per-context record written by AddRight and read back by MergeRight.
struct BufferEntry {
  // (sum of discounts over the context's n-grams + mass removed by cutoffs)
  // divided by denominator; see AddRight::Run.
  float gamma;
  // Sum of UnmarkedCount() over every n-gram sharing the context.
  float denominator;
};

// BufferEntry plus a hash of the context words.  InitialProbabilities sizes
// the gamma chain with this type whenever pruning is enabled.
struct HashBufferEntry : BufferEntry {
  // util::MurmurHashNative over the context's WordIndex values.
  uint64_t hash_value;
};
00030 
00031 
00032 
00033 
00034 class PruneNGramStream {
00035   public:
00036     PruneNGramStream(const util::stream::ChainPosition &position, const SpecialVocab &specials) :
00037       current_(NULL, NGram<BuildingPayload>::OrderFromSize(position.GetChain().EntrySize())),
00038       dest_(NULL, NGram<BuildingPayload>::OrderFromSize(position.GetChain().EntrySize())),
00039       currentCount_(0),
00040       block_(position),
00041       specials_(specials)
00042     {
00043       StartBlock();
00044     }
00045 
00046     NGram<BuildingPayload> &operator*() { return current_; }
00047     NGram<BuildingPayload> *operator->() { return ¤t_; }
00048 
00049     operator bool() const {
00050       return block_;
00051     }
00052 
00053     PruneNGramStream &operator++() {
00054       assert(block_);
00055       if(UTIL_UNLIKELY(current_.Order() == 1 && specials_.IsSpecial(*current_.begin())))
00056         dest_.NextInMemory();
00057       else if(currentCount_ > 0) {
00058         if(dest_.Base() < current_.Base()) {
00059           memcpy(dest_.Base(), current_.Base(), current_.TotalSize());
00060         }
00061         dest_.NextInMemory();
00062       }
00063 
00064       current_.NextInMemory();
00065 
00066       uint8_t *block_base = static_cast<uint8_t*>(block_->Get());
00067       if (current_.Base() == block_base + block_->ValidSize()) {
00068         block_->SetValidSize(dest_.Base() - block_base);
00069         ++block_;
00070         StartBlock();
00071         if (block_) {
00072           currentCount_ = current_.Value().CutoffCount();
00073         }
00074       } else {
00075         currentCount_ = current_.Value().CutoffCount();
00076       }
00077 
00078       return *this;
00079     }
00080 
00081   private:
00082     void StartBlock() {
00083       for (; ; ++block_) {
00084         if (!block_) return;
00085         if (block_->ValidSize()) break;
00086       }
00087       current_.ReBase(block_->Get());
00088       currentCount_ = current_.Value().CutoffCount();
00089 
00090       dest_.ReBase(block_->Get());
00091     }
00092 
00093     NGram<BuildingPayload> current_; 
00094     NGram<BuildingPayload> dest_;    
00095 
00096     uint64_t currentCount_;
00097 
00098     util::stream::Link block_;
00099 
00100     const SpecialVocab specials_;
00101 };
00102 
00103 
00104 class OnlyGamma {
00105   public:
00106     explicit OnlyGamma(bool pruning) : pruning_(pruning) {}
00107 
00108     void Run(const util::stream::ChainPosition &position) {
00109       for (util::stream::Link block_it(position); block_it; ++block_it) {
00110         if(pruning_) {
00111           const HashBufferEntry *in = static_cast<const HashBufferEntry*>(block_it->Get());
00112           const HashBufferEntry *end = static_cast<const HashBufferEntry*>(block_it->ValidEnd());
00113 
00114           
00115           
00116           HashGamma *out = static_cast<HashGamma*>(block_it->Get());
00117           for (; in < end; out += 1, in += 1) {
00118             
00119             float gamma_buf = in->gamma;
00120             uint64_t hash_buf = in->hash_value;
00121 
00122             out->gamma = gamma_buf;
00123             out->hash_value = hash_buf;
00124           }
00125           block_it->SetValidSize((block_it->ValidSize() * sizeof(HashGamma)) / sizeof(HashBufferEntry));
00126         }
00127         else {
00128           float *out = static_cast<float*>(block_it->Get());
00129           const float *in = out;
00130           const float *end = static_cast<const float*>(block_it->ValidEnd());
00131           for (out += 1, in += 2; in < end; out += 1, in += 2) {
00132             *out = *in;
00133           }
00134           block_it->SetValidSize(block_it->ValidSize() / 2);
00135         }
00136       }
00137     }
00138 
00139     private:
00140       bool pruning_;
00141 };
00142 
00143 class AddRight {
00144   public:
00145     AddRight(const Discount &discount, const util::stream::ChainPosition &input, bool pruning)
00146       : discount_(discount), input_(input), pruning_(pruning) {}
00147 
00148     void Run(const util::stream::ChainPosition &output) {
00149       NGramStream<BuildingPayload> in(input_);
00150       util::stream::Stream out(output);
00151 
00152       std::vector<WordIndex> previous(in->Order() - 1);
00153       
00154       void *const previous_raw = previous.empty() ? NULL : static_cast<void*>(&previous[0]);
00155       const std::size_t size = sizeof(WordIndex) * previous.size();
00156 
00157       for(; in; ++out) {
00158         memcpy(previous_raw, in->begin(), size);
00159         uint64_t denominator = 0;
00160         uint64_t normalizer = 0;
00161 
00162         uint64_t counts[4];
00163         memset(counts, 0, sizeof(counts));
00164         do {
00165           denominator += in->Value().UnmarkedCount();
00166 
00167           
00168           
00169           normalizer += in->Value().UnmarkedCount() - in->Value().CutoffCount();
00170 
00171           
00172           
00173           
00174           if(in->Value().CutoffCount() > 0)
00175             ++counts[std::min(in->Value().CutoffCount(), static_cast<uint64_t>(3))];
00176 
00177         } while (++in && !memcmp(previous_raw, in->begin(), size));
00178 
00179         BufferEntry &entry = *reinterpret_cast<BufferEntry*>(out.Get());
00180         entry.denominator = static_cast<float>(denominator);
00181         entry.gamma = 0.0;
00182         for (unsigned i = 1; i <= 3; ++i) {
00183           entry.gamma += discount_.Get(i) * static_cast<float>(counts[i]);
00184         }
00185 
00186         
00187         entry.gamma += normalizer;
00188 
00189         entry.gamma /= entry.denominator;
00190 
00191         if(pruning_) {
00192           
00193           
00194           static_cast<HashBufferEntry*>(&entry)->hash_value = util::MurmurHashNative(previous_raw, size);
00195         }
00196       }
00197       out.Poison();
00198     }
00199 
00200   private:
00201     const Discount &discount_;
00202     const util::stream::ChainPosition input_;
00203     bool pruning_;
00204 };
00205 
00206 class MergeRight {
00207   public:
00208     MergeRight(bool interpolate_unigrams, const util::stream::ChainPosition &from_adder, const Discount &discount, const SpecialVocab &specials)
00209       : interpolate_unigrams_(interpolate_unigrams), from_adder_(from_adder), discount_(discount), specials_(specials) {}
00210 
00211     
00212     
00213     void Run(const util::stream::ChainPosition &primary) {
00214       util::stream::Stream summed(from_adder_);
00215 
00216       PruneNGramStream grams(primary, specials_);
00217 
00218       
00219       if (grams->Order() == 1) {
00220         BufferEntry sums(*static_cast<const BufferEntry*>(summed.Get()));
00221         
00222         assert(*grams->begin() == kUNK);
00223         float gamma_assign;
00224         if (interpolate_unigrams_) {
00225           
00226           gamma_assign = sums.gamma;
00227           grams->Value().uninterp.prob = 0.0;
00228         } else {
00229           
00230           gamma_assign = 0.0;
00231           grams->Value().uninterp.prob = sums.gamma;
00232         }
00233         grams->Value().uninterp.gamma = gamma_assign;
00234 
00235         for (++grams; *grams->begin() != specials_.BOS(); ++grams) {
00236           grams->Value().uninterp.prob = discount_.Apply(grams->Value().count) / sums.denominator;
00237           grams->Value().uninterp.gamma = gamma_assign;
00238         }
00239 
00240         
00241         
00242         
00243         assert(*grams->begin() == specials_.BOS());
00244         grams->Value().uninterp.prob = 1.0;
00245         grams->Value().uninterp.gamma = 0.0;
00246 
00247         while (++grams) {
00248           grams->Value().uninterp.prob = discount_.Apply(grams->Value().count) / sums.denominator;
00249           grams->Value().uninterp.gamma = gamma_assign;
00250         }
00251         ++summed;
00252         return;
00253       }
00254 
00255       std::vector<WordIndex> previous(grams->Order() - 1);
00256       const std::size_t size = sizeof(WordIndex) * previous.size();
00257       for (; grams; ++summed) {
00258         memcpy(&previous[0], grams->begin(), size);
00259         const BufferEntry &sums = *static_cast<const BufferEntry*>(summed.Get());
00260 
00261         do {
00262           BuildingPayload &pay = grams->Value();
00263           pay.uninterp.prob = discount_.Apply(grams->Value().UnmarkedCount()) / sums.denominator;
00264           pay.uninterp.gamma = sums.gamma;
00265         } while (++grams && !memcmp(&previous[0], grams->begin(), size));
00266       }
00267     }
00268 
00269   private:
00270     bool interpolate_unigrams_;
00271     util::stream::ChainPosition from_adder_;
00272     Discount discount_;
00273     const SpecialVocab specials_;
00274 };
00275 
00276 } 
00277 
00278 void InitialProbabilities(
00279     const InitialProbabilitiesConfig &config,
00280     const std::vector<Discount> &discounts,
00281     util::stream::Chains &primary,
00282     util::stream::Chains &second_in,
00283     util::stream::Chains &gamma_out,
00284     const std::vector<uint64_t> &prune_thresholds,
00285     bool prune_vocab,
00286     const SpecialVocab &specials) {
00287   for (size_t i = 0; i < primary.size(); ++i) {
00288     util::stream::ChainConfig gamma_config = config.adder_out;
00289     if(prune_vocab || prune_thresholds[i] > 0)
00290       gamma_config.entry_size = sizeof(HashBufferEntry);
00291     else
00292       gamma_config.entry_size = sizeof(BufferEntry);
00293 
00294     util::stream::ChainPosition second(second_in[i].Add());
00295     second_in[i] >> util::stream::kRecycle;
00296     gamma_out.push_back(gamma_config);
00297     gamma_out[i] >> AddRight(discounts[i], second, prune_vocab || prune_thresholds[i] > 0);
00298 
00299     primary[i] >> MergeRight(config.interpolate_unigrams, gamma_out[i].Add(), discounts[i], specials);
00300 
00301     
00302     if (i) gamma_out[i] >> OnlyGamma(prune_vocab || prune_thresholds[i] > 0);
00303   }
00304 }
00305 
00306 }}