00001 #include "lm/builder/pipeline.hh"
00002
00003 #include "lm/builder/adjust_counts.hh"
00004 #include "lm/builder/combine_counts.hh"
00005 #include "lm/builder/corpus_count.hh"
00006 #include "lm/builder/hash_gamma.hh"
00007 #include "lm/builder/initial_probabilities.hh"
00008 #include "lm/builder/interpolate.hh"
00009 #include "lm/builder/output.hh"
00010 #include "lm/common/compare.hh"
00011 #include "lm/common/renumber.hh"
00012
00013 #include "lm/sizes.hh"
00014 #include "lm/vocab.hh"
00015
00016 #include "util/exception.hh"
00017 #include "util/file.hh"
00018 #include "util/stream/io.hh"
00019
00020 #include <algorithm>
00021 #include <iostream>
00022 #include <fstream>
00023 #include <vector>
00024
00025 namespace lm { namespace builder {
00026
00027 using util::stream::Sorts;
00028
00029 namespace {
00030
00031 void PrintStatistics(const std::vector<uint64_t> &counts, const std::vector<uint64_t> &counts_pruned, const std::vector<Discount> &discounts) {
00032 std::cerr << "Statistics:\n";
00033 for (size_t i = 0; i < counts.size(); ++i) {
00034 std::cerr << (i + 1) << ' ' << counts_pruned[i];
00035 if(counts[i] != counts_pruned[i])
00036 std::cerr << "/" << counts[i];
00037
00038 for (size_t d = 1; d <= 3; ++d)
00039 std::cerr << " D" << d << (d == 3 ? "+=" : "=") << discounts[i].amount[d];
00040 std::cerr << '\n';
00041 }
00042 }
00043
00044 class Master {
00045 public:
00046 explicit Master(PipelineConfig &config, unsigned output_steps)
00047 : config_(config), chains_(config.order), unigrams_(util::MakeTemp(config_.TempPrefix())), steps_(output_steps + 4) {
00048 config_.minimum_block = std::max(NGram<BuildingPayload>::TotalSize(config_.order), config_.minimum_block);
00049 }
00050
00051 const PipelineConfig &Config() const { return config_; }
00052
00053 util::stream::Chains &MutableChains() { return chains_; }
00054
00055 template <class T> Master &operator>>(const T &worker) {
00056 chains_ >> worker;
00057 return *this;
00058 }
00059
00060
00061 void InitForAdjust(util::stream::Sort<SuffixOrder, CombineCounts> &ngrams, WordIndex types, std::size_t subtract_for_numbering) {
00062 const std::size_t each_order_min = config_.minimum_block * config_.block_count;
00063
00064 const std::size_t min_chains = (config_.order - 1) * each_order_min +
00065 std::min(types * NGram<BuildingPayload>::TotalSize(1), each_order_min);
00066
00067 const std::size_t total = std::max<std::size_t>(config_.TotalMemory(), min_chains + subtract_for_numbering + config_.minimum_block);
00068
00069 const std::size_t merge_using = ngrams.Merge(std::min(total - min_chains - subtract_for_numbering, ngrams.DefaultLazy()));
00070
00071 std::vector<uint64_t> count_bounds(1, types);
00072 CreateChains(total - merge_using - subtract_for_numbering, count_bounds);
00073 ngrams.Output(chains_.back(), merge_using);
00074 }
00075
00076
00077 void SortAndReadTwice(const std::vector<uint64_t> &counts, Sorts<ContextOrder> &sorts, util::stream::Chains &second, util::stream::ChainConfig second_config) {
00078 bool unigrams_are_sorted = !config_.renumber_vocabulary;
00079
00080 for (std::size_t i = 0; i < config_.order - unigrams_are_sorted; ++i) {
00081 sorts[i].Merge(0);
00082 }
00083
00084 CreateChains(config_.TotalMemory(), counts);
00085 chains_.back().ActivateProgress();
00086 if (unigrams_are_sorted) {
00087 chains_[0] >> unigrams_.Source();
00088 second_config.entry_size = NGram<BuildingPayload>::TotalSize(1);
00089 second.push_back(second_config);
00090 second.back() >> unigrams_.Source();
00091 }
00092 for (std::size_t i = unigrams_are_sorted; i < config_.order; ++i) {
00093 util::scoped_fd fd(sorts[i - unigrams_are_sorted].StealCompleted());
00094 chains_[i].SetProgressTarget(util::SizeOrThrow(fd.get()));
00095 chains_[i] >> util::stream::PRead(util::DupOrThrow(fd.get()), true);
00096 second_config.entry_size = NGram<BuildingPayload>::TotalSize(i + 1);
00097 second.push_back(second_config);
00098 second.back() >> util::stream::PRead(fd.release(), true);
00099 }
00100 }
00101
00102
00103 template <class Compare> void MaximumLazyInput(const std::vector<uint64_t> &counts, Sorts<Compare> &sorts) {
00104
00105 std::size_t min_chains = 0;
00106 for (std::size_t i = 0; i < config_.order; ++i) {
00107 min_chains += std::min(counts[i] * NGram<BuildingPayload>::TotalSize(i + 1), static_cast<uint64_t>(config_.minimum_block));
00108 }
00109 std::size_t for_merge = min_chains > config_.TotalMemory() ? 0 : (config_.TotalMemory() - min_chains);
00110 std::vector<std::size_t> laziness;
00111
00112 for (util::stream::Sort<SuffixOrder> *i = sorts.end() - 1; i >= sorts.begin(); --i) {
00113 laziness.push_back(i->Merge(for_merge));
00114 assert(for_merge >= laziness.back());
00115 for_merge -= laziness.back();
00116 }
00117 std::reverse(laziness.begin(), laziness.end());
00118
00119 CreateChains(for_merge + min_chains, counts);
00120 chains_.back().ActivateProgress();
00121 chains_[0] >> unigrams_.Source();
00122 for (std::size_t i = 1; i < config_.order; ++i) {
00123 sorts[i - 1].Output(chains_[i], laziness[i - 1]);
00124 }
00125 }
00126
00127 template <class Compare> void SetupSorts(Sorts<Compare> &sorts, bool exclude_unigrams) {
00128 sorts.Init(config_.order - exclude_unigrams);
00129
00130 if (exclude_unigrams) chains_[0] >> unigrams_.Sink();
00131 for (std::size_t i = exclude_unigrams; i < config_.order; ++i) {
00132 sorts.push_back(chains_[i], config_.sort, Compare(i + 1));
00133 }
00134 chains_.Wait(true);
00135 }
00136
00137 unsigned int Steps() const { return steps_; }
00138
00139 private:
00140
00141
00142 void CreateChains(std::size_t remaining_mem, const std::vector<uint64_t> &count_bounds) {
00143 std::vector<std::size_t> assignments;
00144 assignments.reserve(config_.order);
00145
00146 for (std::size_t i = 0; i < count_bounds.size(); ++i) {
00147 assignments.push_back(static_cast<std::size_t>(std::min(
00148 static_cast<uint64_t>(remaining_mem),
00149 count_bounds[i] * static_cast<uint64_t>(NGram<BuildingPayload>::TotalSize(i + 1)))));
00150 }
00151 assignments.resize(config_.order, remaining_mem);
00152
00153
00154
00155 std::vector<float> portions;
00156
00157 std::vector<std::size_t> unassigned;
00158 for (std::size_t i = 0; i < config_.order; ++i) {
00159 portions.push_back(static_cast<float>((i+1) * NGram<BuildingPayload>::TotalSize(i+1)));
00160 unassigned.push_back(i);
00161 }
00162
00163
00164
00165
00166 float sum;
00167 bool found_more;
00168 std::vector<std::size_t> block_count(config_.order);
00169 do {
00170 sum = 0.0;
00171 for (std::size_t i = 0; i < unassigned.size(); ++i) {
00172 sum += portions[unassigned[i]];
00173 }
00174 found_more = false;
00175
00176 for (std::vector<std::size_t>::iterator i = unassigned.begin(); i != unassigned.end();) {
00177 if (assignments[*i] <= remaining_mem * (portions[*i] / sum)) {
00178 remaining_mem -= assignments[*i];
00179 block_count[*i] = 1;
00180 i = unassigned.erase(i);
00181 found_more = true;
00182 } else {
00183 ++i;
00184 }
00185 }
00186 } while (found_more);
00187 for (std::vector<std::size_t>::iterator i = unassigned.begin(); i != unassigned.end(); ++i) {
00188 assignments[*i] = remaining_mem * (portions[*i] / sum);
00189 block_count[*i] = config_.block_count;
00190 }
00191 chains_.clear();
00192 std::cerr << "Chain sizes:";
00193 for (std::size_t i = 0; i < config_.order; ++i) {
00194
00195
00196 assignments[i] = std::max(assignments[i], block_count[i] * NGram<BuildingPayload>::TotalSize(i + 1));
00197 std::cerr << ' ' << (i+1) << ":" << assignments[i];
00198 chains_.push_back(util::stream::ChainConfig(NGram<BuildingPayload>::TotalSize(i + 1), block_count[i], assignments[i]));
00199 }
00200 std::cerr << std::endl;
00201 }
00202
00203 PipelineConfig &config_;
00204
00205 util::stream::Chains chains_;
00206
00207 util::stream::FileBuffer unigrams_;
00208
00209 const unsigned int steps_;
00210 };
00211
// Step 1: read the corpus, count n-grams of the highest order, and sort them
// in suffix order, combining duplicate counts.  Returns an owning pointer to
// the sort; also fills in token_count, type_count, the input file name, and
// the per-word pruning mask.
util::stream::Sort<SuffixOrder, CombineCounts> *CountText(int text_file /* read here */, int vocab_file /* handed to CorpusCount */, Master &master, uint64_t &token_count, WordIndex &type_count, std::string &text_file_name, std::vector<bool> &prune_words) {
  const PipelineConfig &config = master.Config();
  std::cerr << "=== 1/" << master.Steps() << " Counting and sorting n-grams ===" << std::endl;

  // Fail early if the vocab hash estimate alone exceeds the memory budget.
  const std::size_t vocab_usage = CorpusCount::VocabUsage(config.vocab_estimate);
  UTIL_THROW_IF(config.TotalMemory() < vocab_usage, util::Exception, "Vocab hash size estimate " << vocab_usage << " exceeds total memory " << config.TotalMemory());
  std::size_t memory_for_chain =
    // Memory left after the vocabulary hash table...
    static_cast<float>(config.TotalMemory() - vocab_usage) /
    // ...divided by chain blocks plus counting overhead (DedupeMultiplier is
    // presumably CorpusCount's per-block dedupe cost — confirm in corpus_count)...
    (static_cast<float>(config.block_count) + CorpusCount::DedupeMultiplier(config.order)) *
    // ...times the number of blocks actually given to the chain.
    static_cast<float>(config.block_count);
  util::stream::Chain chain(util::stream::ChainConfig(NGram<BuildingPayload>::TotalSize(config.order), config.block_count, memory_for_chain));

  // In/out: CorpusCount takes the estimate and reports the actual type count.
  type_count = config.vocab_estimate;
  util::FilePiece text(text_file, NULL, &std::cerr);
  text_file_name = text.FileName();
  CorpusCount counter(text, vocab_file, token_count, type_count, prune_words, config.prune_vocab_file, chain.BlockSize() / chain.EntrySize(), config.disallowed_symbol_action);
  // boost::ref so the chain runs this counter in place rather than a copy.
  chain >> boost::ref(counter);

  util::scoped_ptr<util::stream::Sort<SuffixOrder, CombineCounts> > sorter(new util::stream::Sort<SuffixOrder, CombineCounts>(chain, config.sort, SuffixOrder(config.order), CombineCounts()));
  // Drain the chain (releasing its memory) before handing the sort back.
  chain.Wait(true);
  return sorter.release();
}
00237
// Step 3: compute initial (uninterpolated) probabilities and backoff weights.
// The adjusted counts are sorted in context order, read twice (the second
// pass feeds the adder via `second`), and the resulting gamma streams for
// orders 2..N are buffered to temp files for the interpolation step.  On
// return, `primary` holds suffix-order sorts of the probability streams.
void InitialProbabilities(const std::vector<uint64_t> &counts, const std::vector<uint64_t> &counts_pruned, const std::vector<Discount> &discounts, Master &master, Sorts<SuffixOrder> &primary, util::FixedArray<util::stream::FileBuffer> &gammas, const std::vector<uint64_t> &prune_thresholds, bool prune_vocab, const SpecialVocab &specials) {
  const PipelineConfig &config = master.Config();
  util::stream::Chains second(config.order);

  {
    // Scoped so the context-order sorts are destroyed (files released) once
    // both read passes have been wired up.
    Sorts<ContextOrder> sorts;
    // Unigrams are excluded from the sorts only when not renumbering.
    master.SetupSorts(sorts, !config.renumber_vocabulary);
    PrintStatistics(counts, counts_pruned, discounts);
    lm::ngram::ShowSizes(counts_pruned);
    std::cerr << "=== 3/" << master.Steps() << " Calculating and sorting initial probabilities ===" << std::endl;
    master.SortAndReadTwice(counts_pruned, sorts, second, config.initial_probs.adder_in);
  }

  util::stream::Chains gamma_chains(config.order);
  InitialProbabilities(config.initial_probs, discounts, master.MutableChains(), second, gamma_chains, prune_thresholds, prune_vocab, specials);

  // The unigram gamma stream is not needed later; recycle its blocks.
  gamma_chains[0] >> util::stream::kRecycle;
  // Buffer gammas for orders 2..N to temporary files.
  gammas.Init(config.order - 1);
  for (std::size_t i = 1; i < config.order; ++i) {
    gammas.push_back(util::MakeTemp(config.TempPrefix()));
    gamma_chains[i] >> gammas[i - 1].Sink();
  }

  // Re-sort the probability streams in suffix order for interpolation.
  master.SetupSorts(primary, true);
}
00263
// Step 4: interpolate probabilities across orders and hand the finished
// streams to the output sink.  Replays the buffered gamma (backoff) streams
// alongside the suffix-sorted probability streams.
void InterpolateProbabilities(const std::vector<uint64_t> &counts, Master &master, Sorts<SuffixOrder> &primary, util::FixedArray<util::stream::FileBuffer> &gammas, Output &output, const SpecialVocab &specials) {
  std::cerr << "=== 4/" << master.Steps() << " Calculating and writing order-interpolated probabilities ===" << std::endl;
  const PipelineConfig &config = master.Config();
  master.MaximumLazyInput(counts, primary);

  // One gamma chain per order below the maximum.
  util::stream::Chains gamma_chains(config.order - 1);
  for (std::size_t i = 0; i < config.order - 1; ++i) {
    util::stream::ChainConfig read_backoffs(config.read_backoffs);

    // Pruned orders stored hashed gammas; unpruned orders stored bare floats
    // (must mirror how InitialProbabilities wrote them).
    if(config.prune_vocab || config.prune_thresholds[i + 1] > 0)
      read_backoffs.entry_size = sizeof(HashGamma);
    else
      read_backoffs.entry_size = sizeof(float);

    gamma_chains.push_back(read_backoffs);
    gamma_chains.back() >> gammas[i].Source(true);
  }
  // counts[0] - 1 apparently excludes one special token from the unigram
  // count used for <unk> sizing — confirm against Interpolate.
  master >> Interpolate(std::max(master.Config().vocab_size_for_unk, counts[0] - 1 ), util::stream::ChainPositions(gamma_chains), config.prune_thresholds, config.prune_vocab, config.output_q, specials);
  gamma_chains >> util::stream::kRecycle;
  output.SinkProbs(master.MutableChains());
}
00285
// Manages optional renumbering of the vocabulary.  When renumbering is
// enabled, the on-the-fly vocab goes to a temporary file; ComputeMapping then
// derives an old-id -> new-id mapping (writing the final vocab file) and
// ApplyRenumber rewrites word ids in the n-gram streams.
class VocabNumbering {
  public:
    VocabNumbering(int final_vocab, StringPiece temp_prefix, bool renumber)
      : final_vocab_(final_vocab),
        renumber_(renumber),
        specials_(kBOS, kEOS) {
      if (renumber) {
        temporary_.reset(util::MakeTemp(temp_prefix));
      }
    }

    // File descriptor the corpus counter should write the vocabulary to:
    // the temporary when renumbering, otherwise the final vocab file.
    int WriteOnTheFly() const { return renumber_ ? temporary_.get() : final_vocab_; }

    // Build the renumbering map from the temporary vocab and write the final
    // vocab file.  Returns the memory consumed by the map (0 when not
    // renumbering) so the caller can subtract it from the chain budget.
    std::size_t ComputeMapping(WordIndex type_count) {
      if (!renumber_) return 0;
      ngram::SortedVocabulary::ComputeRenumbering(type_count, temporary_.get(), final_vocab_, vocab_mapping_);
      // The temporary vocab is no longer needed once the map exists.
      temporary_.reset();
      return sizeof(WordIndex) * vocab_mapping_.size();
    }

    // Insert a Renumber worker into each chain and remap the special tokens.
    void ApplyRenumber(util::stream::Chains &chains) {
      if (!renumber_) return;
      for (std::size_t i = 0; i < chains.size(); ++i) {
        chains[i] >> Renumber(&*vocab_mapping_.begin(), i + 1);
      }
      specials_ = SpecialVocab(vocab_mapping_[specials_.BOS()], vocab_mapping_[specials_.EOS()]);
    }

    const SpecialVocab &Specials() const { return specials_; }

  private:
    // Final vocab file descriptor (not owned here).
    int final_vocab_;

    // Holds the unsorted on-the-fly vocab while renumbering is pending.
    util::scoped_fd temporary_;

    bool renumber_;

    // Maps old word ids to renumbered ids; empty until ComputeMapping runs.
    std::vector<WordIndex> vocab_mapping_;

    // BOS/EOS ids; updated by ApplyRenumber.
    SpecialVocab specials_;
};
00328
00329 }
00330
// Entry point: run the full estimation pipeline over text_file — count,
// adjust counts (Kneser-Ney discounting), initial probabilities,
// interpolation — and write the model through `output`.
void Pipeline(PipelineConfig &config, int text_file, Output &output) {
  // Fail-fast sanity checks / adjustments on the memory configuration.
  if (config.sort.buffer_size * 4 > config.TotalMemory()) {
    config.sort.buffer_size = config.TotalMemory() / 4;
    std::cerr << "Warning: changing sort block size to " << config.sort.buffer_size << " bytes due to low total memory." << std::endl;
  }
  if (config.minimum_block < NGram<BuildingPayload>::TotalSize(config.order)) {
    config.minimum_block = NGram<BuildingPayload>::TotalSize(config.order);
    std::cerr << "Warning: raising minimum block to " << config.minimum_block << " to fit an ngram in every block." << std::endl;
  }
  UTIL_THROW_IF(config.sort.buffer_size < config.minimum_block, util::Exception, "Sort block size " << config.sort.buffer_size << " is below the minimum block size " << config.minimum_block << ".");
  UTIL_THROW_IF(config.TotalMemory() < config.minimum_block * config.order * config.block_count, util::Exception,
      "Not enough memory to fit " << (config.order * config.block_count) << " blocks with minimum size " << config.minimum_block << ". Increase memory to " << (config.minimum_block * config.order * config.block_count) << " bytes or decrease the minimum block size.");

  Master master(config, output.Steps());

  try {
    VocabNumbering numbering(output.VocabFile(), config.TempPrefix(), config.renumber_vocabulary);
    uint64_t token_count;
    WordIndex type_count;
    std::string text_file_name;
    std::vector<bool> prune_words;
    // Step 1: count and sort n-grams; vocab is written on the fly.
    util::scoped_ptr<util::stream::Sort<SuffixOrder, CombineCounts> > sorted_counts(
        CountText(text_file, numbering.WriteOnTheFly(), master, token_count, type_count, text_file_name, prune_words));
    std::cerr << "Unigram tokens " << token_count << " types " << type_count << std::endl;

    // Build the renumbering map now (needs the final type count); its memory
    // footprint is subtracted from the chain budget below.
    std::size_t subtract_for_numbering = numbering.ComputeMapping(type_count);

    std::cerr << "=== 2/" << master.Steps() << " Calculating and sorting adjusted counts ===" << std::endl;
    master.InitForAdjust(*sorted_counts, type_count, subtract_for_numbering);
    // Release the sort's resources; the chains now own the data flow.
    sorted_counts.reset();

    std::vector<uint64_t> counts;
    std::vector<uint64_t> counts_pruned;
    std::vector<Discount> discounts;
    master >> AdjustCounts(config.prune_thresholds, counts, counts_pruned, prune_words, config.discount, discounts);
    // Renumbering is applied to the adjusted-count streams.
    numbering.ApplyRenumber(master.MutableChains());

    {
      util::FixedArray<util::stream::FileBuffer> gammas;
      Sorts<SuffixOrder> primary;
      InitialProbabilities(counts, counts_pruned, discounts, master, primary, gammas, config.prune_thresholds, config.prune_vocab, numbering.Specials());
      output.SetHeader(HeaderInfo(text_file_name, token_count, counts_pruned));
      // InterpolateProbabilities does the last output step.
      InterpolateProbabilities(counts_pruned, master, primary, gammas, output, numbering.Specials());
    }
  } catch (const util::Exception &e) {
    std::cerr << e.what() << std::endl;
    // NOTE(review): aborts the process instead of rethrowing — presumably
    // because worker threads on the chains cannot be unwound safely; confirm.
    abort();
  }
}
00384
00385 }}