#include "lm/model.hh"
#include "util/file_stream.hh"
#include "util/file.hh"
#include "util/file_piece.hh"
#include "util/usage.hh"

#include <cstring>
#include <exception>
#include <iostream>

#include <stdint.h>
00008
00009 namespace {
00010
00011 template <class Model, class Width> void ConvertToBytes(const Model &model, int fd_in) {
00012 util::FilePiece in(fd_in);
00013 util::FileStream out(1);
00014 Width width;
00015 StringPiece word;
00016 const Width end_sentence = (Width)model.GetVocabulary().EndSentence();
00017 while (true) {
00018 while (in.ReadWordSameLine(word)) {
00019 width = (Width)model.GetVocabulary().Index(word);
00020 out.write(&width, sizeof(Width));
00021 }
00022 if (!in.ReadLineOrEOF(word)) break;
00023 out.write(&end_sentence, sizeof(Width));
00024 }
00025 }
00026
00027 template <class Model, class Width> void QueryFromBytes(const Model &model, int fd_in) {
00028 lm::ngram::State state[3];
00029 const lm::ngram::State *const begin_state = &model.BeginSentenceState();
00030 const lm::ngram::State *next_state = begin_state;
00031 Width kEOS = model.GetVocabulary().EndSentence();
00032 Width buf[4096];
00033
00034 uint64_t completed = 0;
00035 double loaded = util::CPUTime();
00036
00037 std::cout << "CPU_to_load: " << loaded << std::endl;
00038
00039
00040 double total = 0.0;
00041 while (std::size_t got = util::ReadOrEOF(fd_in, buf, sizeof(buf))) {
00042 float sum = 0.0;
00043 UTIL_THROW_IF2(got % sizeof(Width), "File size not a multiple of vocab id size " << sizeof(Width));
00044 got /= sizeof(Width);
00045 completed += got;
00046
00047 const Width *even_end = buf + (got & ~1);
00048
00049 const Width *i;
00050 for (i = buf; i != even_end;) {
00051 sum += model.FullScore(*next_state, *i, state[1]).prob;
00052 next_state = (*i++ == kEOS) ? begin_state : &state[1];
00053 sum += model.FullScore(*next_state, *i, state[0]).prob;
00054 next_state = (*i++ == kEOS) ? begin_state : &state[0];
00055 }
00056
00057 if (got & 1) {
00058 sum += model.FullScore(*next_state, *i, state[2]).prob;
00059 next_state = (*i++ == kEOS) ? begin_state : &state[2];
00060 }
00061 total += sum;
00062 }
00063 double after = util::CPUTime();
00064 std::cerr << "Probability sum is " << total << std::endl;
00065 std::cout << "Queries: " << completed << std::endl;
00066 std::cout << "CPU_excluding_load: " << (after - loaded) << "\nCPU_per_query: " << ((after - loaded) / static_cast<double>(completed)) << std::endl;
00067 std::cout << "RSSMax: " << util::RSSMax() << std::endl;
00068 }
00069
00070 template <class Model, class Width> void DispatchFunction(const Model &model, bool query) {
00071 if (query) {
00072 QueryFromBytes<Model, Width>(model, 0);
00073 } else {
00074 ConvertToBytes<Model, Width>(model, 0);
00075 }
00076 }
00077
00078 template <class Model> void DispatchWidth(const char *file, bool query) {
00079 lm::ngram::Config config;
00080 config.load_method = util::READ;
00081 std::cerr << "Using load_method = READ." << std::endl;
00082 Model model(file, config);
00083 lm::WordIndex bound = model.GetVocabulary().Bound();
00084 if (bound <= 256) {
00085 DispatchFunction<Model, uint8_t>(model, query);
00086 } else if (bound <= 65536) {
00087 DispatchFunction<Model, uint16_t>(model, query);
00088 } else if (bound <= (1ULL << 32)) {
00089 DispatchFunction<Model, uint32_t>(model, query);
00090 } else {
00091 DispatchFunction<Model, uint64_t>(model, query);
00092 }
00093 }
00094
00095 void Dispatch(const char *file, bool query) {
00096 using namespace lm::ngram;
00097 lm::ngram::ModelType model_type;
00098 if (lm::ngram::RecognizeBinary(file, model_type)) {
00099 switch(model_type) {
00100 case PROBING:
00101 DispatchWidth<lm::ngram::ProbingModel>(file, query);
00102 break;
00103 case REST_PROBING:
00104 DispatchWidth<lm::ngram::RestProbingModel>(file, query);
00105 break;
00106 case TRIE:
00107 DispatchWidth<lm::ngram::TrieModel>(file, query);
00108 break;
00109 case QUANT_TRIE:
00110 DispatchWidth<lm::ngram::QuantTrieModel>(file, query);
00111 break;
00112 case ARRAY_TRIE:
00113 DispatchWidth<lm::ngram::ArrayTrieModel>(file, query);
00114 break;
00115 case QUANT_ARRAY_TRIE:
00116 DispatchWidth<lm::ngram::QuantArrayTrieModel>(file, query);
00117 break;
00118 default:
00119 UTIL_THROW(util::Exception, "Unrecognized kenlm model type " << model_type);
00120 }
00121 } else {
00122 UTIL_THROW(util::Exception, "Binarize before running benchmarks.");
00123 }
00124 }
00125
00126 }
00127
00128 int main(int argc, char *argv[]) {
00129 if (argc != 3 || (strcmp(argv[1], "vocab") && strcmp(argv[1], "query"))) {
00130 std::cerr
00131 << "Benchmark program for KenLM. Intended usage:\n"
00132 << "#Convert text to vocabulary ids offline. These ids are tied to a model.\n"
00133 << argv[0] << " vocab $model <$text >$text.vocab\n"
00134 << "#Ensure files are in RAM.\n"
00135 << "cat $text.vocab $model >/dev/null\n"
00136 << "#Timed query against the model.\n"
00137 << argv[0] << " query $model <$text.vocab\n";
00138 return 1;
00139 }
00140 Dispatch(argv[2], !strcmp(argv[1], "query"));
00141 return 0;
00142 }