#include "lm/builder/output.hh"
#include "lm/builder/pipeline.hh"
#include "lm/common/size_option.hh"
#include "lm/lm_exception.hh"
#include "util/file.hh"
#include "util/file_piece.hh"
#include "util/usage.hh"

#include <boost/lexical_cast.hpp>
#include <boost/program_options.hpp>
#include <boost/version.hpp>

#include <iostream>
#include <vector>
namespace {

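// Parse --prune thresholds, one per order.  Worked example (illustrative,
// matching the logic below): with --prune 0 0 1 and order 5, the strings
// parse to {0, 0, 1}, pass the non-decreasing check, and are padded with the
// last value to {0, 0, 1, 1, 1}, so singleton trigrams through 5-grams are
// pruned.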
std::vector<uint64_t> ParsePruning(const std::vector<std::string> &param, std::size_t order) {
  // Convert the command-line strings to integer thresholds.
  std::vector<uint64_t> prune_thresholds;
  prune_thresholds.reserve(order);
  for (std::vector<std::string>::const_iterator it(param.begin()); it != param.end(); ++it) {
    try {
      prune_thresholds.push_back(boost::lexical_cast<uint64_t>(*it));
    } catch(const boost::bad_lexical_cast &) {
      UTIL_THROW(util::Exception, "Bad pruning threshold " << *it);
    }
  }

  // No thresholds given: default to no pruning at any order.
  if (prune_thresholds.empty()) {
    prune_thresholds.resize(order, 0);
    return prune_thresholds;
  }

  UTIL_THROW_IF(prune_thresholds.size() > order, util::Exception, "You specified pruning thresholds for orders 1 through " << prune_thresholds.size() << " but the model only has order " << order);

  // Thresholds must be non-decreasing across orders; otherwise a pruned
  // n-gram could have an unpruned extension, which query-time data
  // structures cannot represent.
  uint64_t lower_threshold = 0;
  for (std::vector<uint64_t>::iterator it = prune_thresholds.begin(); it != prune_thresholds.end(); ++it) {
    UTIL_THROW_IF(lower_threshold > *it, util::Exception, "Pruning thresholds should be in non-decreasing order. Otherwise substrings would be removed, which is bad for query-time data structures.");
    lower_threshold = *it;
  }

  // Pad to the full order by repeating the last threshold.
  prune_thresholds.resize(order, prune_thresholds.back());
  return prune_thresholds;
}

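// Parse --discount_fallback values.  Worked example (illustrative): with
// --discount_fallback 0.5 1 1.5, amount becomes {0, 0.5, 1, 1.5} (amount[0]
// is unused).  If fewer than three values are given, the last one repeats,
// so --discount_fallback 0.5 yields {0, 0.5, 0.5, 0.5}.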
lm::builder::Discount ParseDiscountFallback(const std::vector<std::string> &param) {
  lm::builder::Discount ret;
  UTIL_THROW_IF(param.size() > 3, util::Exception, "Specify at most three fallback discounts: 1, 2, and 3+");
  UTIL_THROW_IF(param.empty(), util::Exception, "Fallback discounting enabled, but no discount specified");
  ret.amount[0] = 0.0;
  for (unsigned i = 0; i < 3; ++i) {
    // Reuse the last provided value for higher counts.
    float discount = boost::lexical_cast<float>(param[i < param.size() ? i : (param.size() - 1)]);
    UTIL_THROW_IF(discount < 0.0 || discount > static_cast<float>(i+1), util::Exception, "The discount for count " << (i+1) << " was parsed as " << discount << " which is not in the range [0, " << (i+1) << "].");
    ret.amount[i + 1] = discount;
  }
  return ret;
}

} // namespace

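// Typical invocation (paths are illustrative): train a 5-gram model, reading
// the corpus from stdin and writing ARPA to stdout, with 80% of physical
// memory for sorting and temporary files under /tmp:
//   bin/lmplz -o 5 -S 80% -T /tmp <corpus.txt >model.arpa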
int main(int argc, char *argv[]) {
  try {
    namespace po = boost::program_options;
    po::options_description options("Language model building options");
    lm::builder::PipelineConfig pipeline;

    std::string text, intermediate, arpa;
    std::vector<std::string> pruning;
    std::vector<std::string> discount_fallback;
    std::vector<std::string> discount_fallback_default;
    discount_fallback_default.push_back("0.5");
    discount_fallback_default.push_back("1");
    discount_fallback_default.push_back("1.5");
    bool verbose_header;
    options.add_options()
      ("help,h", po::bool_switch(), "Show this help message")
      ("order,o", po::value<std::size_t>(&pipeline.order)
#if BOOST_VERSION >= 104200
        ->required()
#endif
        , "Order of the model")
      ("interpolate_unigrams", po::value<bool>(&pipeline.initial_probs.interpolate_unigrams)->default_value(true)->implicit_value(true), "Interpolate the unigrams (default) as opposed to giving lots of mass to <unk> like SRI. If you want SRI's behavior with a large <unk> and the old lmplz default, use --interpolate_unigrams 0.")
      ("skip_symbols", po::bool_switch(), "Treat <s>, </s>, and <unk> as whitespace instead of throwing an exception")
      ("temp_prefix,T", po::value<std::string>(&pipeline.sort.temp_prefix)->default_value("/tmp/lm"), "Temporary file prefix")
      ("memory,S", lm::SizeOption(pipeline.sort.total_memory, util::GuessPhysicalMemory() ? "80%" : "1G"), "Sorting memory")
      ("minimum_block", lm::SizeOption(pipeline.minimum_block, "8K"), "Minimum block size to allow")
      ("sort_block", lm::SizeOption(pipeline.sort.buffer_size, "64M"), "Size of IO operations for sort (determines arity)")
      ("block_count", po::value<std::size_t>(&pipeline.block_count)->default_value(2), "Block count (per order)")
      ("vocab_estimate", po::value<lm::WordIndex>(&pipeline.vocab_estimate)->default_value(1000000), "Assume this vocabulary size for purposes of calculating memory in step 1 (corpus count) and pre-sizing the hash table")
      ("vocab_pad", po::value<uint64_t>(&pipeline.vocab_size_for_unk)->default_value(0), "If the vocabulary is smaller than this value, pad with <unk> to reach this size. Requires --interpolate_unigrams")
      ("verbose_header", po::bool_switch(&verbose_header), "Add a verbose header to the ARPA file that includes information such as token count, smoothing type, etc.")
      ("text", po::value<std::string>(&text), "Read text from a file instead of stdin")
      ("arpa", po::value<std::string>(&arpa), "Write ARPA to a file instead of stdout")
      ("intermediate", po::value<std::string>(&intermediate), "Write ngrams to intermediate files. Turns off ARPA output (which can be reactivated by --arpa file). Forces --renumber on.")
      ("renumber", po::bool_switch(&pipeline.renumber_vocabulary), "Renumber the vocabulary identifiers so that they are monotone with the hash of each string. This is consistent with the ordering used by the trie data structure.")
      ("collapse_values", po::bool_switch(&pipeline.output_q), "Collapse probability and backoff into a single value, q, that yields the same sentence-level probabilities. See http://kheafield.com/professional/edinburgh/rest_paper.pdf for more details, including a proof.")
      ("prune", po::value<std::vector<std::string> >(&pruning)->multitoken(), "Prune n-grams with count less than or equal to the given threshold. Specify one value for each order, e.g. 0 0 1 to prune singleton trigrams and above. The sequence of values must be non-decreasing and the last value applies to any remaining orders. Default is to not prune, which is equivalent to --prune 0.")
      ("limit_vocab_file", po::value<std::string>(&pipeline.prune_vocab_file)->default_value(""), "Read allowed vocabulary separated by whitespace. N-grams that contain vocabulary items not in this list will be pruned. Can be combined with --prune arg")
      ("discount_fallback", po::value<std::vector<std::string> >(&discount_fallback)->multitoken()->implicit_value(discount_fallback_default, "0.5 1 1.5"), "The closed-form estimate for Kneser-Ney discounts does not work without singletons or doubletons. It can also fail if these values are out of range. This option falls back to user-specified discounts when the closed-form estimate fails. Note that this option is generally a bad idea: you should deduplicate your corpus instead. However, class-based models need custom discounts because they lack singleton unigrams. Provide up to three discounts (for adjusted counts 1, 2, and 3+), which will be applied to all orders where the closed-form estimates fail.");
    po::variables_map vm;
    po::store(po::parse_command_line(argc, argv, options), vm);

    if (argc == 1 || vm["help"].as<bool>()) {
      std::cerr <<
          "Builds unpruned language models with modified Kneser-Ney smoothing.\n\n"
          "Please cite:\n"
          "@inproceedings{Heafield-estimate,\n"
          "  author = {Kenneth Heafield and Ivan Pouzyrevsky and Jonathan H. Clark and Philipp Koehn},\n"
          "  title = {Scalable Modified {Kneser-Ney} Language Model Estimation},\n"
          "  year = {2013},\n"
          "  month = {8},\n"
          "  booktitle = {Proceedings of the 51st Annual Meeting of the Association for Computational Linguistics},\n"
          "  address = {Sofia, Bulgaria},\n"
          "  url = {http://kheafield.com/professional/edinburgh/estimate\\_paper.pdf},\n"
          "}\n\n"
          "Provide the corpus on stdin. The ARPA file will be written to stdout. Order of\n"
          "the model (-o) is the only mandatory option. As this is an on-disk program,\n"
          "setting the temporary file location (-T) and sorting memory (-S) is recommended.\n\n"
          "Memory sizes are specified like GNU sort: a number followed by a unit character.\n"
          "Valid units are % for percentage of memory (supported platforms only) and (in\n"
          "increasing powers of 1024): b, K, M, G, T, P, E, Z, Y. Default is K (*1024).\n";
      uint64_t mem = util::GuessPhysicalMemory();
      if (mem) {
        std::cerr << "This machine has " << mem << " bytes of memory.\n\n";
      } else {
        std::cerr << "Unable to determine the amount of memory on this machine.\n\n";
      }
      std::cerr << options << std::endl;
      return 1;
    }

    po::notify(vm);

    // ->required() needs Boost >= 1.42.0, so enforce --order by hand on
    // older Boost.
#if BOOST_VERSION < 104200
    if (!vm.count("order")) {
      std::cerr << "the option '--order' is required but missing" << std::endl;
      return 1;
    }
#endif

    if (pipeline.vocab_size_for_unk && !pipeline.initial_probs.interpolate_unigrams) {
      std::cerr << "--vocab_pad requires --interpolate_unigrams to be on" << std::endl;
      return 1;
    }

    // Warn about <s>, </s>, and <unk> in the input instead of aborting.
    if (vm["skip_symbols"].as<bool>()) {
      pipeline.disallowed_symbol_action = lm::COMPLAIN;
    } else {
      pipeline.disallowed_symbol_action = lm::THROW_UP;
    }

    if (vm.count("discount_fallback")) {
      pipeline.discount.fallback = ParseDiscountFallback(discount_fallback);
      pipeline.discount.bad_action = lm::COMPLAIN;
    } else {
      // Fallback unused; fail hard if the closed-form estimate is bad.
      pipeline.discount.fallback = lm::builder::Discount();
      pipeline.discount.bad_action = lm::THROW_UP;
    }

    // Parse pruning thresholds here (not in a notifier) because they depend
    // on the model order.
    pipeline.prune_thresholds = ParsePruning(pruning, pipeline.order);

    pipeline.prune_vocab = !vm["limit_vocab_file"].as<std::string>().empty();

    util::NormalizeTempPrefix(pipeline.sort.temp_prefix);

    lm::builder::InitialProbabilitiesConfig &initial = pipeline.initial_probs;
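    // Fixed, small buffers (32 KB, two blocks) for the streams that compute
    // initial probabilities; reading backoffs reuses the adder output
    // configuration.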
    initial.adder_in.total_memory = 32768;
    initial.adder_in.block_count = 2;
    initial.adder_out.total_memory = 32768;
    initial.adder_out.block_count = 2;
    pipeline.read_backoffs = initial.adder_out;
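    // Default to stdin (fd 0) and stdout (fd 1); --text and --arpa swap in
    // files below.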
    util::scoped_fd in(0), out(1);
    if (vm.count("text")) {
      in.reset(util::OpenReadOrThrow(text.c_str()));
    }
    if (vm.count("arpa")) {
      out.reset(util::CreateOrThrow(arpa.c_str()));
    }

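    // --intermediate forces vocabulary renumbering and writes n-grams to
    // files; the ARPA print hook is attached only when ARPA output is still
    // wanted (always, unless intermediate output is on without --arpa).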
    try {
      bool writing_intermediate = vm.count("intermediate");
      if (writing_intermediate) {
        pipeline.renumber_vocabulary = true;
      }
      lm::builder::Output output(writing_intermediate ? intermediate : pipeline.sort.temp_prefix, writing_intermediate, pipeline.output_q);
      if (!writing_intermediate || vm.count("arpa")) {
        output.Add(new lm::builder::PrintHook(out.release(), verbose_header));
      }
      lm::builder::Pipeline(pipeline, in.release(), output);
    } catch (const util::MallocException &e) {
      std::cerr << e.what() << std::endl;
      std::cerr << "Try rerunning with a more conservative -S setting than " << vm["memory"].as<std::string>() << std::endl;
      return 1;
    }
    util::PrintUsage(std::cerr);
  } catch (const std::exception &e) {
    std::cerr << e.what() << std::endl;
    return 1;
  }
}