00001 #include "lm/builder/initial_probabilities.hh"
00002
00003 #include "lm/builder/discount.hh"
00004 #include "lm/builder/hash_gamma.hh"
00005 #include "lm/builder/payload.hh"
00006 #include "lm/common/special.hh"
00007 #include "lm/common/ngram_stream.hh"
00008 #include "util/murmur_hash.hh"
00009 #include "util/file.hh"
00010 #include "util/stream/chain.hh"
00011 #include "util/stream/io.hh"
00012 #include "util/stream/stream.hh"
00013
00014 #include <vector>
00015
00016 namespace lm { namespace builder {
00017
00018 namespace {
// One record per context, written by AddRight::Run and read back by
// MergeRight::Run.  The member order matters: OnlyGamma::Run walks a block of
// these as consecutive float pairs and keeps the first float of each pair.
struct BufferEntry {
  // (sum over i in 1..3 of discount.Get(i) * counts[i], plus the normalizer)
  // divided by denominator, as computed in AddRight::Run.
  float gamma;

  // Sum of UnmarkedCount() over every n-gram sharing the context.
  float denominator;
};
00025
00026 struct HashBufferEntry : public BufferEntry {
00027
00028 uint64_t hash_value;
00029 };
00030
00031
00032
00033
00034 class PruneNGramStream {
00035 public:
00036 PruneNGramStream(const util::stream::ChainPosition &position, const SpecialVocab &specials) :
00037 current_(NULL, NGram<BuildingPayload>::OrderFromSize(position.GetChain().EntrySize())),
00038 dest_(NULL, NGram<BuildingPayload>::OrderFromSize(position.GetChain().EntrySize())),
00039 currentCount_(0),
00040 block_(position),
00041 specials_(specials)
00042 {
00043 StartBlock();
00044 }
00045
00046 NGram<BuildingPayload> &operator*() { return current_; }
00047 NGram<BuildingPayload> *operator->() { return ¤t_; }
00048
00049 operator bool() const {
00050 return block_;
00051 }
00052
00053 PruneNGramStream &operator++() {
00054 assert(block_);
00055 if(UTIL_UNLIKELY(current_.Order() == 1 && specials_.IsSpecial(*current_.begin())))
00056 dest_.NextInMemory();
00057 else if(currentCount_ > 0) {
00058 if(dest_.Base() < current_.Base()) {
00059 memcpy(dest_.Base(), current_.Base(), current_.TotalSize());
00060 }
00061 dest_.NextInMemory();
00062 }
00063
00064 current_.NextInMemory();
00065
00066 uint8_t *block_base = static_cast<uint8_t*>(block_->Get());
00067 if (current_.Base() == block_base + block_->ValidSize()) {
00068 block_->SetValidSize(dest_.Base() - block_base);
00069 ++block_;
00070 StartBlock();
00071 if (block_) {
00072 currentCount_ = current_.Value().CutoffCount();
00073 }
00074 } else {
00075 currentCount_ = current_.Value().CutoffCount();
00076 }
00077
00078 return *this;
00079 }
00080
00081 private:
00082 void StartBlock() {
00083 for (; ; ++block_) {
00084 if (!block_) return;
00085 if (block_->ValidSize()) break;
00086 }
00087 current_.ReBase(block_->Get());
00088 currentCount_ = current_.Value().CutoffCount();
00089
00090 dest_.ReBase(block_->Get());
00091 }
00092
00093 NGram<BuildingPayload> current_;
00094 NGram<BuildingPayload> dest_;
00095
00096 uint64_t currentCount_;
00097
00098 util::stream::Link block_;
00099
00100 const SpecialVocab specials_;
00101 };
00102
00103
00104 class OnlyGamma {
00105 public:
00106 explicit OnlyGamma(bool pruning) : pruning_(pruning) {}
00107
00108 void Run(const util::stream::ChainPosition &position) {
00109 for (util::stream::Link block_it(position); block_it; ++block_it) {
00110 if(pruning_) {
00111 const HashBufferEntry *in = static_cast<const HashBufferEntry*>(block_it->Get());
00112 const HashBufferEntry *end = static_cast<const HashBufferEntry*>(block_it->ValidEnd());
00113
00114
00115
00116 HashGamma *out = static_cast<HashGamma*>(block_it->Get());
00117 for (; in < end; out += 1, in += 1) {
00118
00119 float gamma_buf = in->gamma;
00120 uint64_t hash_buf = in->hash_value;
00121
00122 out->gamma = gamma_buf;
00123 out->hash_value = hash_buf;
00124 }
00125 block_it->SetValidSize((block_it->ValidSize() * sizeof(HashGamma)) / sizeof(HashBufferEntry));
00126 }
00127 else {
00128 float *out = static_cast<float*>(block_it->Get());
00129 const float *in = out;
00130 const float *end = static_cast<const float*>(block_it->ValidEnd());
00131 for (out += 1, in += 2; in < end; out += 1, in += 2) {
00132 *out = *in;
00133 }
00134 block_it->SetValidSize(block_it->ValidSize() / 2);
00135 }
00136 }
00137 }
00138
00139 private:
00140 bool pruning_;
00141 };
00142
00143 class AddRight {
00144 public:
00145 AddRight(const Discount &discount, const util::stream::ChainPosition &input, bool pruning)
00146 : discount_(discount), input_(input), pruning_(pruning) {}
00147
00148 void Run(const util::stream::ChainPosition &output) {
00149 NGramStream<BuildingPayload> in(input_);
00150 util::stream::Stream out(output);
00151
00152 std::vector<WordIndex> previous(in->Order() - 1);
00153
00154 void *const previous_raw = previous.empty() ? NULL : static_cast<void*>(&previous[0]);
00155 const std::size_t size = sizeof(WordIndex) * previous.size();
00156
00157 for(; in; ++out) {
00158 memcpy(previous_raw, in->begin(), size);
00159 uint64_t denominator = 0;
00160 uint64_t normalizer = 0;
00161
00162 uint64_t counts[4];
00163 memset(counts, 0, sizeof(counts));
00164 do {
00165 denominator += in->Value().UnmarkedCount();
00166
00167
00168
00169 normalizer += in->Value().UnmarkedCount() - in->Value().CutoffCount();
00170
00171
00172
00173
00174 if(in->Value().CutoffCount() > 0)
00175 ++counts[std::min(in->Value().CutoffCount(), static_cast<uint64_t>(3))];
00176
00177 } while (++in && !memcmp(previous_raw, in->begin(), size));
00178
00179 BufferEntry &entry = *reinterpret_cast<BufferEntry*>(out.Get());
00180 entry.denominator = static_cast<float>(denominator);
00181 entry.gamma = 0.0;
00182 for (unsigned i = 1; i <= 3; ++i) {
00183 entry.gamma += discount_.Get(i) * static_cast<float>(counts[i]);
00184 }
00185
00186
00187 entry.gamma += normalizer;
00188
00189 entry.gamma /= entry.denominator;
00190
00191 if(pruning_) {
00192
00193
00194 static_cast<HashBufferEntry*>(&entry)->hash_value = util::MurmurHashNative(previous_raw, size);
00195 }
00196 }
00197 out.Poison();
00198 }
00199
00200 private:
00201 const Discount &discount_;
00202 const util::stream::ChainPosition input_;
00203 bool pruning_;
00204 };
00205
00206 class MergeRight {
00207 public:
00208 MergeRight(bool interpolate_unigrams, const util::stream::ChainPosition &from_adder, const Discount &discount, const SpecialVocab &specials)
00209 : interpolate_unigrams_(interpolate_unigrams), from_adder_(from_adder), discount_(discount), specials_(specials) {}
00210
00211
00212
00213 void Run(const util::stream::ChainPosition &primary) {
00214 util::stream::Stream summed(from_adder_);
00215
00216 PruneNGramStream grams(primary, specials_);
00217
00218
00219 if (grams->Order() == 1) {
00220 BufferEntry sums(*static_cast<const BufferEntry*>(summed.Get()));
00221
00222 assert(*grams->begin() == kUNK);
00223 float gamma_assign;
00224 if (interpolate_unigrams_) {
00225
00226 gamma_assign = sums.gamma;
00227 grams->Value().uninterp.prob = 0.0;
00228 } else {
00229
00230 gamma_assign = 0.0;
00231 grams->Value().uninterp.prob = sums.gamma;
00232 }
00233 grams->Value().uninterp.gamma = gamma_assign;
00234
00235 for (++grams; *grams->begin() != specials_.BOS(); ++grams) {
00236 grams->Value().uninterp.prob = discount_.Apply(grams->Value().count) / sums.denominator;
00237 grams->Value().uninterp.gamma = gamma_assign;
00238 }
00239
00240
00241
00242
00243 assert(*grams->begin() == specials_.BOS());
00244 grams->Value().uninterp.prob = 1.0;
00245 grams->Value().uninterp.gamma = 0.0;
00246
00247 while (++grams) {
00248 grams->Value().uninterp.prob = discount_.Apply(grams->Value().count) / sums.denominator;
00249 grams->Value().uninterp.gamma = gamma_assign;
00250 }
00251 ++summed;
00252 return;
00253 }
00254
00255 std::vector<WordIndex> previous(grams->Order() - 1);
00256 const std::size_t size = sizeof(WordIndex) * previous.size();
00257 for (; grams; ++summed) {
00258 memcpy(&previous[0], grams->begin(), size);
00259 const BufferEntry &sums = *static_cast<const BufferEntry*>(summed.Get());
00260
00261 do {
00262 BuildingPayload &pay = grams->Value();
00263 pay.uninterp.prob = discount_.Apply(grams->Value().UnmarkedCount()) / sums.denominator;
00264 pay.uninterp.gamma = sums.gamma;
00265 } while (++grams && !memcmp(&previous[0], grams->begin(), size));
00266 }
00267 }
00268
00269 private:
00270 bool interpolate_unigrams_;
00271 util::stream::ChainPosition from_adder_;
00272 Discount discount_;
00273 const SpecialVocab specials_;
00274 };
00275
00276 }
00277
00278 void InitialProbabilities(
00279 const InitialProbabilitiesConfig &config,
00280 const std::vector<Discount> &discounts,
00281 util::stream::Chains &primary,
00282 util::stream::Chains &second_in,
00283 util::stream::Chains &gamma_out,
00284 const std::vector<uint64_t> &prune_thresholds,
00285 bool prune_vocab,
00286 const SpecialVocab &specials) {
00287 for (size_t i = 0; i < primary.size(); ++i) {
00288 util::stream::ChainConfig gamma_config = config.adder_out;
00289 if(prune_vocab || prune_thresholds[i] > 0)
00290 gamma_config.entry_size = sizeof(HashBufferEntry);
00291 else
00292 gamma_config.entry_size = sizeof(BufferEntry);
00293
00294 util::stream::ChainPosition second(second_in[i].Add());
00295 second_in[i] >> util::stream::kRecycle;
00296 gamma_out.push_back(gamma_config);
00297 gamma_out[i] >> AddRight(discounts[i], second, prune_vocab || prune_thresholds[i] > 0);
00298
00299 primary[i] >> MergeRight(config.interpolate_unigrams, gamma_out[i].Add(), discounts[i], specials);
00300
00301
00302 if (i) gamma_out[i] >> OnlyGamma(prune_vocab || prune_thresholds[i] > 0);
00303 }
00304 }
00305
00306 }}