00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021 using namespace std;
00022
00023 #include <iostream>
00024 #include <fstream>
00025 #include <sstream>
00026 #include <stdexcept>
00027 #include <vector>
00028 #include <string>
00029 #include <stdlib.h>
00030 #include "cmd.h"
00031 #include "util.h"
00032 #include "math.h"
00033 #include "lmContainer.h"
00034
00035
00036
/**
 * Report a fatal condition: echo the message on stderr, then abort the
 * current operation by throwing std::runtime_error carrying the same text.
 */
inline void error(const char* message)
{
  const std::string what(message);
  std::cerr << what << "\n";
  throw std::runtime_error(what);
}
00042
00043 lmContainer* load_lm(std::string file,int requiredMaxlev,int dub,int memmap, float nlf, float dlf);
00044
00045 void print_help(int TypeFlag=0){
00046 std::cerr << std::endl << "interpolate-lm - interpolates language models" << std::endl;
00047 std::cerr << std::endl << "USAGE:" << std::endl;
00048 std::cerr << " interpolate-lm [options] <lm-list-file> [lm-list-file.out]" << std::endl;
00049
00050 std::cerr << std::endl << "DESCRIPTION:" << std::endl;
00051 std::cerr << " interpolate-lm reads a LM list file including interpolation weights " << std::endl;
00052 std::cerr << " with the format: N\\n w1 lm1 \\n w2 lm2 ...\\n wN lmN\n" << std::endl;
00053 std::cerr << " It estimates new weights on a development text, " << std::endl;
00054 std::cerr << " computes the perplexity on an evaluation text, " << std::endl;
00055 std::cerr << " computes probabilities of n-grams read from stdin." << std::endl;
00056 std::cerr << " It reads LMs in ARPA and IRSTLM binary format." << std::endl;
00057
00058 std::cerr << std::endl << "OPTIONS:" << std::endl;
00059 FullPrintParams(TypeFlag, 0, 1, stderr);
00060
00061 }
00062
00063 void usage(const char *msg = 0)
00064 {
00065 if (msg){
00066 std::cerr << msg << std::endl;
00067 }
00068 else{
00069 print_help();
00070 }
00071 exit(1);
00072 }
00073
00074 int main(int argc, char **argv)
00075 {
00076 char *slearn = NULL;
00077 char *seval = NULL;
00078 bool learn=false;
00079 bool score=false;
00080 bool sent_PP_flag = false;
00081
00082 int order = 0;
00083 int debug = 0;
00084 int memmap = 0;
00085 int requiredMaxlev = 1000;
00086 int dub = 10000000;
00087 float ngramcache_load_factor = 0.0;
00088 float dictionary_load_factor = 0.0;
00089
00090 bool help=false;
00091 std::vector<std::string> files;
00092
00093 DeclareParams((char*)
00094
00095 "learn", CMDSTRINGTYPE|CMDMSG, &slearn, "learn optimal interpolation for text-file; default is false",
00096 "l", CMDSTRINGTYPE|CMDMSG, &slearn, "learn optimal interpolation for text-file; default is false",
00097 "order", CMDINTTYPE|CMDMSG, &order, "order of n-grams used in --learn (optional)",
00098 "o", CMDINTTYPE|CMDMSG, &order, "order of n-grams used in --learn (optional)",
00099 "eval", CMDSTRINGTYPE|CMDMSG, &seval, "computes perplexity of the specified text file",
00100 "e", CMDSTRINGTYPE|CMDMSG, &seval, "computes perplexity of the specified text file",
00101
00102 "DictionaryUpperBound", CMDINTTYPE|CMDMSG, &dub, "dictionary upperbound to compute OOV word penalty: default 10^7",
00103 "dub", CMDINTTYPE|CMDMSG, &dub, "dictionary upperbound to compute OOV word penalty: default 10^7",
00104 "score", CMDBOOLTYPE|CMDMSG, &score, "computes log-prob scores of n-grams from standard input",
00105 "s", CMDBOOLTYPE|CMDMSG, &score, "computes log-prob scores of n-grams from standard input",
00106
00107 "debug", CMDINTTYPE|CMDMSG, &debug, "verbose output for --eval option; default is 0",
00108 "d", CMDINTTYPE|CMDMSG, &debug, "verbose output for --eval option; default is 0",
00109 "memmap", CMDINTTYPE|CMDMSG, &memmap, "uses memory map to read a binary LM",
00110 "mm", CMDINTTYPE|CMDMSG, &memmap, "uses memory map to read a binary LM",
00111 "sentence", CMDBOOLTYPE|CMDMSG, &sent_PP_flag, "computes perplexity at sentence level (identified through the end symbol)",
00112 "dict_load_factor", CMDFLOATTYPE|CMDMSG, &dictionary_load_factor, "sets the load factor for ngram cache; it should be a positive real value; default is 0",
00113 "ngram_load_factor", CMDFLOATTYPE|CMDMSG, &ngramcache_load_factor, "sets the load factor for ngram cache; it should be a positive real value; default is false",
00114 "level", CMDINTTYPE|CMDMSG, &requiredMaxlev, "maximum level to load from the LM; if value is larger than the actual LM order, the latter is taken",
00115 "lev", CMDINTTYPE|CMDMSG, &requiredMaxlev, "maximum level to load from the LM; if value is larger than the actual LM order, the latter is taken",
00116
00117 "Help", CMDBOOLTYPE|CMDMSG, &help, "print this help",
00118 "h", CMDBOOLTYPE|CMDMSG, &help, "print this help",
00119
00120 (char *)NULL
00121 );
00122
00123 if (argc == 1){
00124 usage();
00125 }
00126
00127 for(int i=1; i < argc; i++) {
00128 if(argv[i][0] != '-') files.push_back(argv[i]);
00129 }
00130
00131 GetParams(&argc, &argv, (char*) NULL);
00132
00133 if (help){
00134 usage();
00135 }
00136
00137 if (files.size() > 2) {
00138 usage("Warning: Too many arguments");
00139 }
00140
00141 if (files.size() < 1) {
00142 usage("Warning: specify a LM list file to read from");
00143 }
00144
00145 std::string infile = files[0];
00146 std::string outfile="";
00147
00148 if (files.size() == 1) {
00149 outfile=infile;
00150
00151 std::string::size_type p = outfile.rfind('/');
00152 if (p != std::string::npos && ((p+1) < outfile.size()))
00153 outfile.erase(0,p+1);
00154 outfile+=".out";
00155 } else
00156 outfile = files[1];
00157
00158 std::cerr << "inpfile: " << infile << std::endl;
00159 learn = ((slearn != NULL)? true : false);
00160
00161 if (learn) std::cerr << "outfile: " << outfile << std::endl;
00162 if (score) std::cerr << "interactive: " << score << std::endl;
00163 if (memmap) std::cerr << "memory mapping: " << memmap << std::endl;
00164 std::cerr << "loading up to the LM level " << requiredMaxlev << " (if any)" << std::endl;
00165 std::cerr << "order: " << order << std::endl;
00166 if (requiredMaxlev > 0) std::cerr << "loading up to the LM level " << requiredMaxlev << " (if any)" << std::endl;
00167
00168 std::cerr << "dub: " << dub<< std::endl;
00169
00170 lmContainer *lmt[100], *start_lmt[100];
00171 std::string lmf[100];
00172
00173 float w[100];
00174 int N;
00175
00176
00177
00178 std::cerr << "Reading " << infile << "..." << std::endl;
00179 std::fstream inptxt(infile.c_str(),std::ios::in);
00180
00181
00182 char line[BUFSIZ];
00183 const char* words[3];
00184 int tokenN;
00185
00186 inptxt.getline(line,BUFSIZ,'\n');
00187 tokenN = parseWords(line,words,3);
00188
00189 if (tokenN != 2 || ((strcmp(words[0],"LMINTERPOLATION") != 0) && (strcmp(words[0],"lminterpolation")!=0)))
00190 error((char*)"ERROR: wrong header format of configuration file\ncorrect format: LMINTERPOLATION number_of_models\nweight_of_LM_1 filename_of_LM_1\nweight_of_LM_2 filename_of_LM_2");
00191
00192 N=atoi(words[1]);
00193 std::cerr << "Number of LMs: " << N << "..." << std::endl;
00194 if(N > 100) {
00195 std::cerr << "Can't interpolate more than 100 language models." << std::endl;
00196 exit(1);
00197 }
00198
00199 for (int i=0; i<N; i++) {
00200 inptxt.getline(line,BUFSIZ,'\n');
00201 tokenN = parseWords(line,words,3);
00202 if(tokenN != 2) {
00203 std::cerr << "Wrong input format." << std::endl;
00204 exit(1);
00205 }
00206 w[i] = (float) atof(words[0]);
00207 lmf[i] = words[1];
00208
00209 std::cerr << "i:" << i << " w[i]:" << w[i] << " lmf[i]:" << lmf[i] << std::endl;
00210 start_lmt[i] = lmt[i] = load_lm(lmf[i],requiredMaxlev,dub,memmap,ngramcache_load_factor,dictionary_load_factor);
00211 }
00212
00213 inptxt.close();
00214
00215 int maxorder = 0;
00216 for (int i=0; i<N; i++) {
00217 maxorder = (maxorder > lmt[i]->maxlevel())?maxorder:lmt[i]->maxlevel();
00218 }
00219
00220 if (order <= 0) {
00221 order = maxorder;
00222 std::cerr << "order is not set or wrongly set to a non positive value; reset to the maximum order of LMs: " << order << std::endl;
00223 } else if (order > maxorder) {
00224 order = maxorder;
00225 std::cerr << "order is too high; reset to the maximum order of LMs" << order << std::endl;
00226 }
00227
00228
00229 if (learn) {
00230
00231 std::vector<float> p[N];
00232 float c[N];
00233 float den,norm;
00234 float variation=1.0;
00235
00236 dictionary* dict=new dictionary(slearn,1000000,dictionary_load_factor);
00237 ngram ng(dict);
00238 int bos=ng.dict->encode(ng.dict->BoS());
00239 std::ifstream dev(slearn,std::ios::in);
00240
00241 for(;;) {
00242 std::string line;
00243 getline(dev, line);
00244 if(dev.eof())
00245 break;
00246 if(dev.fail()) {
00247 std::cerr << "Problem reading input file " << seval << std::endl;
00248 exit(1);
00249 }
00250 std::istringstream lstream(line);
00251 if(line.substr(0, 29) == "###interpolate-lm:replace-lm ") {
00252 std::string token, newlm;
00253 int id;
00254 lstream >> token >> id >> newlm;
00255 if(id <= 0 || id > N) {
00256 std::cerr << "LM id out of range." << std::endl;
00257 return 1;
00258 }
00259 id--;
00260 if(lmt[id] != start_lmt[id])
00261 delete lmt[id];
00262 lmt[id] = load_lm(newlm,requiredMaxlev,dub,memmap,ngramcache_load_factor,dictionary_load_factor);
00263 continue;
00264 }
00265 while(lstream >> ng) {
00266
00267
00268 if (*ng.wordp(1)==bos) {
00269 ng.size=1;
00270 continue;
00271 }
00272 if (order > 0 && ng.size > order) ng.size=order;
00273 for (int i=0; i<N; i++) {
00274 ngram ong(lmt[i]->getDict());
00275 ong.trans(ng);
00276 double logpr;
00277 logpr = lmt[i]->clprob(ong);
00278 p[i].push_back(pow(10.0,logpr));
00279 }
00280 }
00281
00282 for (int i=0; i<N; i++) lmt[i]->check_caches_levels();
00283 }
00284 dev.close();
00285
00286 while( variation > 0.01 ) {
00287
00288 for (int i=0; i<N; i++) c[i]=0;
00289
00290 for(unsigned i = 0; i < p[0].size(); i++) {
00291 den=0.0;
00292 for(int j = 0; j < N; j++)
00293 den += w[j] * p[j][i];
00294
00295 for(int j = 0; j < N; j++)
00296 c[j] += w[j] * p[j][i] / den;
00297 }
00298
00299 norm=0.0;
00300 for (int i=0; i<N; i++) norm+=c[i];
00301
00302
00303 variation=0.0;
00304 for (int i=0; i<N; i++) {
00305 c[i]/=norm;
00306 variation+=(w[i]>c[i]?(w[i]-c[i]):(c[i]-w[i]));
00307 w[i]=c[i];
00308 }
00309 std::cerr << "Variation " << variation << std::endl;
00310 }
00311
00312
00313 std::cerr << "Saving in " << outfile << "..." << std::endl;
00314
00315 std::fstream outtxt(outfile.c_str(),std::ios::out);
00316 outtxt << "LMINTERPOLATION " << N << "\n";
00317 for (int i=0; i<N; i++) outtxt << w[i] << " " << lmf[i] << "\n";
00318 outtxt.close();
00319 }
00320
00321 for(int i = 0; i < N; i++)
00322 if(lmt[i] != start_lmt[i]) {
00323 delete lmt[i];
00324 lmt[i] = start_lmt[i];
00325 }
00326
00327 if (seval != NULL) {
00328 std::cerr << "Start Eval" << std::endl;
00329
00330 std::cout.setf(ios::fixed);
00331 std::cout.precision(2);
00332 int i;
00333 int Nw=0,Noov_all=0, Noov_any=0, Nbo=0;
00334 double Pr,lPr;
00335 double logPr=0,PP=0;
00336
00337
00338 int sent_Nw=0, sent_Noov_all=0, sent_Noov_any=0, sent_Nbo=0;
00339 double sent_logPr=0,sent_PP=0;
00340
00341
00342 for (i=0,Pr=0; i<N; i++) Pr+=w[i];
00343 for (i=0; i<N; i++) w[i]/=Pr;
00344
00345 dictionary* dict=new dictionary(NULL,1000000,dictionary_load_factor);
00346 dict->incflag(1);
00347 ngram ng(dict);
00348 int bos=ng.dict->encode(ng.dict->BoS());
00349 int eos=ng.dict->encode(ng.dict->EoS());
00350
00351 std::fstream inptxt(seval,std::ios::in);
00352
00353 for(;;) {
00354 std::string line;
00355 getline(inptxt, line);
00356 if(inptxt.eof())
00357 break;
00358 if(inptxt.fail()) {
00359 std::cerr << "Problem reading input file " << seval << std::endl;
00360 return 1;
00361 }
00362 std::istringstream lstream(line);
00363 if(line.substr(0, 26) == "###interpolate-lm:weights ") {
00364 std::string token;
00365 lstream >> token;
00366 for(int i = 0; i < N; i++) {
00367 if(lstream.eof()) {
00368 std::cerr << "Not enough weights!" << std::endl;
00369 return 1;
00370 }
00371 lstream >> w[i];
00372 }
00373 continue;
00374 }
00375 if(line.substr(0, 29) == "###interpolate-lm:replace-lm ") {
00376 std::string token, newlm;
00377 int id;
00378 lstream >> token >> id >> newlm;
00379 if(id <= 0 || id > N) {
00380 std::cerr << "LM id out of range." << std::endl;
00381 return 1;
00382 }
00383 id--;
00384 delete lmt[id];
00385 lmt[id] = load_lm(newlm,requiredMaxlev,dub,memmap,ngramcache_load_factor,dictionary_load_factor);
00386 continue;
00387 }
00388
00389 double bow;
00390 int bol=0;
00391 char *msp;
00392 unsigned int statesize;
00393
00394 while(lstream >> ng) {
00395
00396
00397 if (*ng.wordp(1)==bos) {
00398 ng.size=1;
00399 continue;
00400 }
00401 if (order > 0 && ng.size > order) ng.size=order;
00402
00403
00404 if (ng.size>=1) {
00405
00406 int minbol=MAX_NGRAM;
00407 bool OOV_all_flag=true;
00408 bool OOV_any_flag=false;
00409 float logpr;
00410
00411 Pr = 0.0;
00412 for (i=0; i<N; i++) {
00413
00414 ngram ong(lmt[i]->getDict());
00415 ong.trans(ng);
00416 logpr = lmt[i]->clprob(ong,&bow,&bol,&msp,&statesize);
00417
00418
00419 Pr+=w[i] * pow(10.0,logpr);
00420 if (bol < minbol) minbol=bol;
00421
00422 if (*ong.wordp(1) != lmt[i]->getDict()->oovcode()) OOV_all_flag=false;
00423 if (*ong.wordp(1) == lmt[i]->getDict()->oovcode()) OOV_any_flag=true;
00424 }
00425
00426 lPr=log(Pr)/M_LN10;
00427 logPr+=lPr;
00428 sent_logPr+=lPr;
00429
00430 if (debug==1) {
00431 std::cout << ng.dict->decode(*ng.wordp(1)) << " [" << ng.size-minbol << "]" << " ";
00432 if (*ng.wordp(1)==eos) std::cout << std::endl;
00433 }
00434 if (debug==2)
00435 std::cout << ng << " [" << ng.size-minbol << "-gram]" << " " << log(Pr) << std::endl;
00436
00437 if (minbol) {
00438 Nbo++;
00439 sent_Nbo++;
00440 }
00441
00442 if (OOV_all_flag) {
00443 Noov_all++;
00444 sent_Noov_all++;
00445 }
00446 if (OOV_any_flag) {
00447 Noov_any++;
00448 sent_Noov_any++;
00449 }
00450
00451 Nw++;
00452 sent_Nw++;
00453
00454 if (*ng.wordp(1)==eos && sent_PP_flag) {
00455 sent_PP=exp((-sent_logPr * log(10.0)) /sent_Nw);
00456 std::cout << "%% sent_Nw=" << sent_Nw
00457 << " sent_PP=" << sent_PP
00458 << " sent_Nbo=" << sent_Nbo
00459 << " sent_Noov=" << sent_Noov_all
00460 << " sent_OOV=" << (float)sent_Noov_all/sent_Nw * 100.0 << "%"
00461 << " sent_Noov_any=" << sent_Noov_any
00462 << " sent_OOV_any=" << (float)sent_Noov_any/sent_Nw * 100.0 << "%" << std::endl;
00463
00464 sent_Nw=sent_Noov_any=sent_Noov_all=sent_Nbo=0;
00465 sent_logPr=0.0;
00466 }
00467
00468
00469 if ((Nw % 10000)==0) std::cerr << ".";
00470 }
00471 }
00472 }
00473
00474 PP=exp((-logPr * M_LN10) /Nw);
00475
00476 std::cout << "%% Nw=" << Nw
00477 << " PP=" << PP
00478 << " Nbo=" << Nbo
00479 << " Noov=" << Noov_all
00480 << " OOV=" << (float)Noov_all/Nw * 100.0 << "%"
00481 << " Noov_any=" << Noov_any
00482 << " OOV_any=" << (float)Noov_any/Nw * 100.0 << "%" << std::endl;
00483
00484 };
00485
00486
00487 if (score == true) {
00488
00489
00490 dictionary* dict=new dictionary(NULL,1000000,dictionary_load_factor);
00491 dict->incflag(1);
00492 ngram ng(dict);
00493 int bos=ng.dict->encode(ng.dict->BoS());
00494
00495 double Pr,logpr;
00496
00497 double bow;
00498 int bol=0, maxbol=0;
00499 unsigned int maxstatesize, statesize;
00500 int i,n=0;
00501 std::cout << "> ";
00502 while(std::cin >> ng) {
00503
00504
00505 if (*ng.wordp(1)==bos) {
00506 ng.size=1;
00507 continue;
00508 }
00509
00510 if (ng.size>=maxorder) {
00511
00512 if (order > 0 && ng.size > order) ng.size=order;
00513 n++;
00514 maxstatesize=0;
00515 maxbol=0;
00516 Pr=0.0;
00517 for (i=0; i<N; i++) {
00518 ngram ong(lmt[i]->getDict());
00519 ong.trans(ng);
00520 logpr = lmt[i]->clprob(ong,&bow,&bol,NULL,&statesize);
00521
00522 Pr+=w[i] * pow(10.0,logpr);
00523 std::cout << "lm " << i << ":" << " logpr: " << logpr << " weight: " << w[i] << std::endl;
00524 if (maxbol<bol) maxbol=bol;
00525 if (maxstatesize<statesize) maxstatesize=statesize;
00526 }
00527
00528 std::cout << ng << " p= " << log(Pr) << " bo= " << maxbol << " recombine= " << maxstatesize << std::endl;
00529
00530 if ((n % 10000000)==0) {
00531 std::cerr << "." << std::endl;
00532 for (i=0; i<N; i++) lmt[i]->check_caches_levels();
00533 }
00534
00535 } else {
00536 std::cout << ng << " p= NULL" << std::endl;
00537 }
00538 std::cout << "> ";
00539 }
00540
00541
00542 }
00543
00544 for (int i=0; i<N; i++) delete lmt[i];
00545
00546 return 0;
00547 }
00548
00549 lmContainer* load_lm(std::string file,int requiredMaxlev,int dub,int memmap, float nlf, float dlf)
00550 {
00551 lmContainer* lmt=NULL;
00552
00553 lmt = lmt->CreateLanguageModel(file,nlf,dlf);
00554
00555 lmt->setMaxLoadedLevel(requiredMaxlev);
00556
00557 lmt->load(file,memmap);
00558
00559 if (dub) lmt->setlogOOVpenalty((int)dub);
00560
00561
00562 lmt->init_caches(lmt->maxlevel());
00563 return lmt;
00564 }