00001 #include "lm/filter/arpa_io.hh"
00002 #include "lm/filter/format.hh"
00003 #include "lm/filter/phrase.hh"
00004 #ifndef NTHREAD
00005 #include "lm/filter/thread.hh"
00006 #endif
00007 #include "lm/filter/vocab.hh"
00008 #include "lm/filter/wrapper.hh"
00009 #include "util/exception.hh"
00010 #include "util/file_piece.hh"
00011
00012 #include <boost/ptr_container/ptr_vector.hpp>
00013
00014 #include <cstring>
00015 #include <fstream>
00016 #include <iostream>
00017 #include <memory>
00018
00019 namespace lm {
00020 namespace {
00021
// Print command-line usage for the filter tool to stderr.
// name: program name shown in the usage line (main() passes argv[0]).
// The threading paragraph depends on whether NTHREAD was defined at build time.
void DisplayHelp(const char *name) {
  std::cerr
    << "Usage: " << name << " mode [context] [phrase] [raw|arpa] [threads:m] [batch_size:m] (vocab|model):input_file output_file\n\n"
    "copy mode just copies, but makes the format nicer for e.g. irstlm's broken\n"
    " parser.\n"
    "single mode treats the entire input as a single sentence.\n"
    "multiple mode filters to multiple sentences in parallel. Each sentence is on\n"
    " a separate line. A separate file is created for each sentence by appending\n"
    " the 0-indexed line number to the output file name.\n"
    "union mode produces one filtered model that is the union of models created by\n"
    " multiple mode.\n\n"
    "context means only the context (all but last word) has to pass the filter, but\n"
    " the entire n-gram is output.\n\n"
    "phrase means that the vocabulary is actually tab-delimited phrases and that the\n"
    " phrases can generate the n-gram when assembled in arbitrary order and\n"
    " clipped. Currently works with multiple or union mode.\n\n"
    "The file format is set by [raw|arpa] with default arpa:\n"
    "raw means space-separated tokens, optionally followed by a tab and arbitrary\n"
    " text. This is useful for ngram count files.\n"
    "arpa means the ARPA file format for n-gram language models.\n\n"
#ifndef NTHREAD
    "threads:m sets m threads (default: concurrency detected by boost)\n"
    "batch_size:m sets the batch size for threading. Expect memory usage from this\n"
    " of 2*threads*batch_size n-grams.\n\n"
#else
    "This binary was compiled with -DNTHREAD, disabling threading. If you wanted\n"
    " threading, compile without this flag against Boost >=1.42.0.\n\n"
#endif
    "There are two inputs: vocabulary and model. Either may be given as a file\n"
    " while the other is on stdin. Specify the type given as a file using\n"
    " vocab: or model: before the file name. \n\n"
    "For ARPA format, the output must be seekable. For raw format, it can be a\n"
    " stream, e.g. /dev/stdout\n";
}
00056
// Operation selected on the command line. MODE_UNSET is the placeholder
// main() uses until a mode word is seen; it is rejected before filtering.
enum FilterMode {
  MODE_COPY,
  MODE_SINGLE,
  MODE_MULTIPLE,
  MODE_UNION,
  MODE_UNSET
};

// File format of input and output: FORMAT_ARPA for "arpa", FORMAT_COUNT for "raw".
enum Format {
  FORMAT_ARPA,
  FORMAT_COUNT
};
00059
00060 struct Config {
00061 Config() :
00062 #ifndef NTHREAD
00063 batch_size(25000),
00064 threads(boost::thread::hardware_concurrency()),
00065 #endif
00066 phrase(false),
00067 context(false),
00068 format(FORMAT_ARPA)
00069 {
00070 #ifndef NTHREAD
00071 if (!threads) threads = 1;
00072 #endif
00073 }
00074
00075 #ifndef NTHREAD
00076 size_t batch_size;
00077 size_t threads;
00078 #endif
00079 bool phrase;
00080 bool context;
00081 FilterMode mode;
00082 Format format;
00083 };
00084
00085 template <class Format, class Filter, class OutputBuffer, class Output> void RunThreadedFilter(const Config &config, util::FilePiece &in_lm, Filter &filter, Output &output) {
00086 #ifndef NTHREAD
00087 if (config.threads == 1) {
00088 #endif
00089 Format::RunFilter(in_lm, filter, output);
00090 #ifndef NTHREAD
00091 } else {
00092 typedef Controller<Filter, OutputBuffer, Output> Threaded;
00093 Threaded threading(config.batch_size, config.threads * 2, config.threads, filter, output);
00094 Format::RunFilter(in_lm, threading, output);
00095 }
00096 #endif
00097 }
00098
00099 template <class Format, class Filter, class OutputBuffer, class Output> void RunContextFilter(const Config &config, util::FilePiece &in_lm, Filter filter, Output &output) {
00100 if (config.context) {
00101 ContextFilter<Filter> context_filter(filter);
00102 RunThreadedFilter<Format, ContextFilter<Filter>, OutputBuffer, Output>(config, in_lm, context_filter, output);
00103 } else {
00104 RunThreadedFilter<Format, Filter, OutputBuffer, Output>(config, in_lm, filter, output);
00105 }
00106 }
00107
00108 template <class Format, class Binary> void DispatchBinaryFilter(const Config &config, util::FilePiece &in_lm, const Binary &binary, typename Format::Output &out) {
00109 typedef BinaryFilter<Binary> Filter;
00110 RunContextFilter<Format, Filter, BinaryOutputBuffer, typename Format::Output>(config, in_lm, Filter(binary), out);
00111 }
00112
00113 template <class Format> void DispatchFilterModes(const Config &config, std::istream &in_vocab, util::FilePiece &in_lm, const char *out_name) {
00114 if (config.mode == MODE_MULTIPLE) {
00115 if (config.phrase) {
00116 typedef phrase::Multiple Filter;
00117 phrase::Substrings substrings;
00118 typename Format::Multiple out(out_name, phrase::ReadMultiple(in_vocab, substrings));
00119 RunContextFilter<Format, Filter, MultipleOutputBuffer, typename Format::Multiple>(config, in_lm, Filter(substrings), out);
00120 } else {
00121 typedef vocab::Multiple Filter;
00122 boost::unordered_map<std::string, std::vector<unsigned int> > words;
00123 typename Format::Multiple out(out_name, vocab::ReadMultiple(in_vocab, words));
00124 RunContextFilter<Format, Filter, MultipleOutputBuffer, typename Format::Multiple>(config, in_lm, Filter(words), out);
00125 }
00126 return;
00127 }
00128
00129 typename Format::Output out(out_name);
00130
00131 if (config.mode == MODE_COPY) {
00132 Format::Copy(in_lm, out);
00133 return;
00134 }
00135
00136 if (config.mode == MODE_SINGLE) {
00137 vocab::Single::Words words;
00138 vocab::ReadSingle(in_vocab, words);
00139 DispatchBinaryFilter<Format, vocab::Single>(config, in_lm, vocab::Single(words), out);
00140 return;
00141 }
00142
00143 if (config.mode == MODE_UNION) {
00144 if (config.phrase) {
00145 phrase::Substrings substrings;
00146 phrase::ReadMultiple(in_vocab, substrings);
00147 DispatchBinaryFilter<Format, phrase::Union>(config, in_lm, phrase::Union(substrings), out);
00148 } else {
00149 vocab::Union::Words words;
00150 vocab::ReadMultiple(in_vocab, words);
00151 DispatchBinaryFilter<Format, vocab::Union>(config, in_lm, vocab::Union(words), out);
00152 }
00153 return;
00154 }
00155 }
00156
00157 }
00158 }
00159
00160 int main(int argc, char *argv[]) {
00161 try {
00162 if (argc < 4) {
00163 lm::DisplayHelp(argv[0]);
00164 return 1;
00165 }
00166
00167
00168 lm::Config config;
00169 config.mode = lm::MODE_UNSET;
00170 for (int i = 1; i < argc - 2; ++i) {
00171 const char *str = argv[i];
00172 if (!std::strcmp(str, "copy")) {
00173 config.mode = lm::MODE_COPY;
00174 } else if (!std::strcmp(str, "single")) {
00175 config.mode = lm::MODE_SINGLE;
00176 } else if (!std::strcmp(str, "multiple")) {
00177 config.mode = lm::MODE_MULTIPLE;
00178 } else if (!std::strcmp(str, "union")) {
00179 config.mode = lm::MODE_UNION;
00180 } else if (!std::strcmp(str, "phrase")) {
00181 config.phrase = true;
00182 } else if (!std::strcmp(str, "context")) {
00183 config.context = true;
00184 } else if (!std::strcmp(str, "arpa")) {
00185 config.format = lm::FORMAT_ARPA;
00186 } else if (!std::strcmp(str, "raw")) {
00187 config.format = lm::FORMAT_COUNT;
00188 #ifndef NTHREAD
00189 } else if (!std::strncmp(str, "threads:", 8)) {
00190 config.threads = boost::lexical_cast<size_t>(str + 8);
00191 if (!config.threads) {
00192 std::cerr << "Specify at least one thread." << std::endl;
00193 return 1;
00194 }
00195 } else if (!std::strncmp(str, "batch_size:", 11)) {
00196 config.batch_size = boost::lexical_cast<size_t>(str + 11);
00197 if (config.batch_size < 5000) {
00198 std::cerr << "Batch size must be at least one and should probably be >= 5000" << std::endl;
00199 if (!config.batch_size) return 1;
00200 }
00201 #endif
00202 } else {
00203 lm::DisplayHelp(argv[0]);
00204 return 1;
00205 }
00206 }
00207
00208 if (config.mode == lm::MODE_UNSET) {
00209 lm::DisplayHelp(argv[0]);
00210 return 1;
00211 }
00212
00213 if (config.phrase && config.mode != lm::MODE_UNION && config.mode != lm::MODE_MULTIPLE) {
00214 std::cerr << "Phrase constraint currently only works in multiple or union mode. If you really need it for single, put everything on one line and use union." << std::endl;
00215 return 1;
00216 }
00217
00218 bool cmd_is_model = true;
00219 const char *cmd_input = argv[argc - 2];
00220 if (!strncmp(cmd_input, "vocab:", 6)) {
00221 cmd_is_model = false;
00222 cmd_input += 6;
00223 } else if (!strncmp(cmd_input, "model:", 6)) {
00224 cmd_input += 6;
00225 } else if (strchr(cmd_input, ':')) {
00226 std::cerr << "Specify vocab: or model: before the input file name, not " << cmd_input << std::endl;
00227 return 1;
00228 } else {
00229 std::cerr << "Assuming that " << cmd_input << " is a model file" << std::endl;
00230 }
00231 std::ifstream cmd_file;
00232 std::istream *vocab;
00233 if (cmd_is_model) {
00234 vocab = &std::cin;
00235 } else {
00236 cmd_file.open(cmd_input, std::ios::in);
00237 UTIL_THROW_IF(!cmd_file, util::ErrnoException, "Failed to open " << cmd_input);
00238 vocab = &cmd_file;
00239 }
00240
00241 util::FilePiece model(cmd_is_model ? util::OpenReadOrThrow(cmd_input) : 0, cmd_is_model ? cmd_input : NULL, &std::cerr);
00242
00243 if (config.format == lm::FORMAT_ARPA) {
00244 lm::DispatchFilterModes<lm::ARPAFormat>(config, *vocab, model, argv[argc - 1]);
00245 } else if (config.format == lm::FORMAT_COUNT) {
00246 lm::DispatchFilterModes<lm::CountFormat>(config, *vocab, model, argv[argc - 1]);
00247 }
00248 return 0;
00249 } catch (const std::exception &e) {
00250 std::cerr << e.what() << std::endl;
00251 return 1;
00252 }
00253 }