00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021 using namespace std;
00022
00023 #include <cmath>
00024 #include "mfstream.h"
00025 #include "mempool.h"
00026 #include "htable.h"
00027 #include "dictionary.h"
00028 #include "n_gram.h"
00029 #include "mempool.h"
00030 #include "ngramcache.h"
00031 #include "ngramtable.h"
00032 #include "normcache.h"
00033 #include "interplm.h"
00034
00035 void interplm::trainunigr()
00036 {
00037
00038 int oov=dict->getcode(dict->OOV());
00039 cerr << "oovcode: " << oov << "\n";
00040
00041 if (oov>=0 && dict->freq(oov)>= dict->size()) {
00042 cerr << "Using current estimate of OOV frequency " << dict->freq(oov)<< "\n";
00043 } else {
00044 oov=dict->encode(dict->OOV());
00045 dict->oovcode(oov);
00046
00047
00048
00049
00050
00051
00052 if (unismooth) {
00053 dict->incfreq(oov,dict->size()-1);
00054 cerr << "Witten-Bell estimate of OOV freq:"<< (double)(dict->size()-1)/dict->totfreq() << "\n";
00055 } else {
00056 if (dict->dub()) {
00057 cerr << "DUB estimate of OOV size\n";
00058 dict->incfreq(oov,dict->dub()-dict->size()+1);
00059 } else {
00060 cerr << "1 = estimate of OOV size\n";
00061 dict->incfreq(oov,1);
00062 }
00063 }
00064 }
00065 }
00066
00067
00068 double interplm::unigr(ngram ng)
00069 {
00070
00071 return
00072 ((double)(dict->freq(*ng.wordp(1))+epsilon))/
00073 ((double)dict->totfreq() + (double) dict->size() * epsilon);
00074
00075 }
00076
00077
00078 interplm::interplm(char *ngtfile,int depth,TABLETYPE tabtype):
00079 ngramtable(ngtfile,depth,NULL,NULL,NULL,0,0,NULL,0,tabtype)
00080 {
00081
00082 if (maxlevel()<depth) {
00083 cerr << "interplm: ngramtable size is too low\n";
00084 exit(1);
00085 }
00086
00087 lms=depth;
00088 unitbl=NULL;
00089 epsilon=1.0;
00090 unismooth=1;
00091 prune_singletons=0;
00092 prune_top_singletons=0;
00093
00094
00095
00096 int BoS=dict->encode(dict->BoS());
00097 if (BoS != dict->oovcode()) {
00098 cerr << "setting counter of Begin of Sentence to 1 ..." << "\n";
00099 dict->freq(BoS,1);
00100 cerr << "start_sent: " << (char *)dict->decode(BoS) << " "
00101 << dict->freq(BoS) << "\n";
00102 }
00103
00104 };
00105
00106
00107 void interplm::gensuccstat()
00108 {
00109
00110 ngram hg(dict);
00111 int s1,s2;
00112
00113 cerr << "Generating successor statistics\n";
00114
00115
00116 for (int l=2; l<=lms; l++) {
00117
00118 cerr << "level " << l << "\n";
00119
00120 scan(hg,INIT,l-1);
00121 while(scan(hg,CONT,l-1)) {
00122
00123 s1=s2=0;
00124
00125 ngram ng=hg;
00126 ng.pushc(0);
00127
00128 succscan(hg,ng,INIT,l);
00129 while(succscan(hg,ng,CONT,l)) {
00130
00131
00132 if (corrcounts && l<lms)
00133 ng.freq=getfreq(ng.link,ng.pinfo,1);
00134
00135 if (ng.freq==1) s1++;
00136 else if (ng.freq==2) s2++;
00137 }
00138
00139 succ2(hg.link,s2);
00140 succ1(hg.link,s1);
00141 }
00142 }
00143 }
00144
00145
/*
 * Builds "corrected" n-gram counts in frequency slot 1 of the table
 * (the last argument of getfreq/setfreq), then:
 *   1. recomputes the main (slot 0) counts of levels lms-2..1 as the sum of
 *      the corrected counts of their one-word extensions,
 *   2. inserts an OOV unigram if none exists (count = dict->size()),
 *   3. forces the corrected count of the BoS unigram to 1,
 *   4. stores the sum of corrected unigram counts via btotfreq(),
 *   5. sets corrcounts=1, which gensuccstat() checks.
 *
 * NOTE(review): presumably the corrected counts start at 0 before this runs;
 * this function only increments them (except on the restart path below).
 */
void interplm::gencorrcounts()
{

  cerr << "Generating corrected n-gram tables\n";

  // Corrected counts are computed top-down, from level lms-1 to level 1.
  for (int l=lms-1; l>=1; l--) {

    cerr << "level " << l << "\n";

    ngram ng(dict);
    // NOTE(review): `count` is never incremented, so the diagnostic below
    // always prints 0.
    int count=0;


    scan(ng,INIT,l+1);
    while(scan(ng,CONT,l+1)) {

      // ng2 is ng shortened by one word.
      // Assumes size-- drops the oldest word, so ng2 is the lower-order
      // suffix of ng. TODO confirm in n_gram.h.
      ngram ng2=ng;
      ng2.size--;
      if (get(ng2,ng2.size,ng2.size)) {

        // Presumably checks whether BoS is at position 1 of ng2 (verify
        // containsWord semantics).
        if (!ng2.containsWord(dict->BoS(),1))
          // Each (l+1)-gram adds one to the corrected count of its suffix.
          setfreq(ng2.link,ng2.pinfo,1+getfreq(ng2.link,ng2.pinfo,1),1);
        else
          // ...otherwise the corrected count is just the raw count.
          setfreq(ng2.link,ng2.pinfo,ng2.freq,1);
      } else {
        // Recovery path: the suffix is missing from the table. This is only
        // expected at the first (highest) level processed.
        assert(lms==l+1);
        cerr << "cannot find2 " << ng2 << "count " << count << "\n";
        cerr << "inserting ngram and starting from scratch\n";
        ng2.pushw(dict->BoS());
        ng2.freq=100;
        put(ng2);

        cerr << "reset all counts at last level\n";

        // Clear the corrected counts already written at level lms-1, then
        // redo the whole computation with the repaired table.
        scan(ng2,INIT,lms-1);
        while(scan(ng2,CONT,lms-1)) {
          setfreq(ng2.link,ng2.pinfo,0,1);
        }

        gencorrcounts();
        return;
      }
    }
  }

  cerr << "Updating history counts\n";

  // Main counts of levels lms-2..1 are rebuilt from the corrected counts
  // of the level above.
  for (int l=lms-2; l>=1; l--) {

    cerr << "level " << l << "\n";

    cerr << "reset counts\n";

    ngram ng(dict);
    scan(ng,INIT,l);
    while(scan(ng,CONT,l)) {
      freq(ng.link,ng.pinfo,0);
    }

    scan(ng,INIT,l+1);
    while(scan(ng,CONT,l+1)) {

      // ng2 -> the level-l node on the path of the (l+1)-gram ng
      // (presumably its history). Add ng's corrected count to it.
      ngram ng2=ng;
      get(ng2,l+1,l);
      freq(ng2.link,ng2.pinfo,freq(ng2.link,ng2.pinfo)+getfreq(ng.link,ng.pinfo,1));
    }
  }

  cerr << "Adding unigram of OOV word if missing\n";
  // An n-gram made only of OOV codes; get(ng,lms,1) looks up its
  // level-1 entry.
  ngram ng(dict,maxlevel());
  for (int i=1; i<=maxlevel(); i++)
    *ng.wordp(i)=dict->oovcode();

  if (!get(ng,lms,1)) {


    // Missing OOV unigram: insert it with count = vocabulary size and copy
    // that value into the corrected-count slot.
    ng.freq=dict->size();
    cerr << "adding oov unigram " << ng << "\n";
    put(ng);
    get(ng,lms,1);
    setfreq(ng.link,ng.pinfo,ng.freq,1);
  }

  cerr << "Replacing unigram of BoS \n";
  // Only when BoS has its own code (is not mapped to OOV).
  if (dict->encode(dict->BoS()) != dict->oovcode()) {
    ngram ng(dict,1);
    *ng.wordp(1)=dict->encode(dict->BoS());

    if (get(ng,1,1)) {
      ng.freq=1;
      setfreq(ng.link,ng.pinfo,ng.freq,1);
    }
  }


  cerr << "compute unigram totfreq \n";
  // NOTE(review): an int accumulator may overflow on very large corpora.
  int totf=0;
  scan(ng,INIT,1);
  while(scan(ng,CONT,1)) {
    totf+=getfreq(ng.link,ng.pinfo,1);
  }

  btotfreq(totf);

  corrcounts=1;


}
00257
00258
00259
00260
00261
00262
00263
00264
00265
00266
00267
00268
00269
00270
00271
00272
00273
00274
00275
00276
00277
00278
00279
00280
00281
00282
00283
00284
00285
00286
00287
00288
00289
00290
00291
00292
00293
00294
00295
00296
00297
00298
00299
00300
00301
00302
00303
00304
00305
00306
00307
00308
00309
00310
00311
00312
00313
00314
00315
00316
00317
00318
00319
00320
00321 double interplm::zerofreq(int lev)
00322 {
00323 cerr << "Computing lambda: ...";
00324 ngram ng(dict);
00325 double N=0,N1=0;
00326 scan(ng,INIT,lev);
00327 while(scan(ng,CONT,lev)) {
00328 if ((lev==1) && (*ng.wordp(1)==dict->oovcode()))
00329 continue;
00330 N+=ng.freq;
00331 if (ng.freq==1) N1++;
00332 }
00333 cerr << (double)(N1/N) << "\n";
00334 return N1/N;
00335 }
00336
00337
00338 void interplm::test(char* filename,int size,int backoff,int checkpr,char* outpr)
00339 {
00340
00341 if (size>lmsize()) {
00342 cerr << "test: wrong ngram size\n";
00343 exit(1);
00344 }
00345
00346
00347 mfstream inp(filename,ios::in );
00348
00349 char header[100];
00350 inp >> header;
00351 inp.close();
00352
00353 if (strncmp(header,"nGrAm",5)==0 ||
00354 strncmp(header,"NgRaM",5)==0) {
00355 ngramtable ngt(filename,size,NULL,NULL,NULL,0,0,NULL,0,COUNT);
00356 test_ngt(ngt,size,backoff,checkpr);
00357 } else
00358 test_txt(filename,size,backoff,checkpr,outpr);
00359 }
00360
00361
00362 void interplm::test_txt(char* filename,int size,int ,int checkpr,char* outpr)
00363 {
00364
00365 cerr << "test text " << filename << " ";
00366 mfstream inp(filename,ios::in );
00367 ngram ng(dict);
00368
00369 double n=0,lp=0,pr;
00370 double oov=0;
00371 cout.precision(10);
00372 mfstream outp(outpr?outpr:"/dev/null",ios::out );
00373
00374 if (checkpr)
00375 cerr << "checking probabilities\n";
00376
00377 while(inp >> ng)
00378 if (ng.size>=1) {
00379
00380 ng.size=ng.size>size?size:ng.size;
00381
00382 if (dict->encode(dict->BoS()) != dict->oovcode()) {
00383 if (*ng.wordp(1) == dict->encode(dict->BoS())) {
00384 ng.size=1;
00385 continue;
00386 }
00387 }
00388
00389 pr=prob(ng,ng.size);
00390
00391 if (outpr)
00392 outp << ng << "[" << ng.size << "-gram]" << " " << pr << " " << log(pr)/log(10.0) << std::endl;
00393
00394 lp-=log(pr);
00395
00396 n++;
00397
00398 if (((int) n % 10000)==0) cerr << ".";
00399
00400 if (*ng.wordp(1) == dict->oovcode()) oov++;
00401
00402 if (checkpr) {
00403 double totp=0.0;
00404 int oldw=*ng.wordp(1);
00405 for (int c=0; c<dict->size(); c++) {
00406 *ng.wordp(1)=c;
00407 totp+=prob(ng,ng.size);
00408 }
00409 *ng.wordp(1)=oldw;
00410
00411 if ( totp < (1.0 - 1e-5) || totp > (1.0 + 1e-5))
00412 cout << ng << " " << pr << " [t="<< totp << "] ***\n";
00413 }
00414
00415 }
00416
00417 if (oov && dict->dub()>obswrd())
00418 lp += oov * log(dict->dub() - obswrd());
00419
00420 cout << "n=" << (int) n << " LP="
00421 << (double) lp
00422 << " PP=" << exp(lp/n)
00423 << " OVVRate=" << (oov)/n
00424
00425
00426
00427 << "\n";
00428
00429
00430 outp.close();
00431 inp.close();
00432 }
00433
00434
00435 void interplm::test_ngt(ngramtable& ngt,int sz,int ,int checkpr)
00436 {
00437
00438 double pr;
00439 int n=0,c=0;
00440 double lp=0;
00441 double oov=0;
00442 cout.precision(10);
00443
00444 if (sz > ngt.maxlevel()) {
00445 cerr << "test_ngt: ngramtable has uncompatible size\n";
00446 exit(1);
00447 }
00448
00449 if (checkpr) cerr << "checking probabilities\n";
00450
00451 cerr << "Computing PP:";
00452
00453 ngram ng(dict);
00454 ngram ng2(ngt.dict);
00455 ngt.scan(ng2,INIT,sz);
00456
00457 while(ngt.scan(ng2,CONT,sz)) {
00458
00459 ng.trans(ng2);
00460
00461 if (dict->encode(dict->BoS()) != dict->oovcode()) {
00462 if (*ng.wordp(1) == dict->encode(dict->BoS())) {
00463 ng.size=1;
00464 continue;
00465 }
00466 }
00467
00468 n+=ng.freq;
00469 pr=prob(ng,sz);
00470
00471 lp-=(ng.freq * log(pr));
00472
00473 if (*ng.wordp(1) == dict->oovcode())
00474 oov+=ng.freq;
00475
00476
00477 if (checkpr) {
00478 double totp=0.0;
00479 for (c=0; c<dict->size(); c++) {
00480 *ng.wordp(1)=c;
00481 totp+=prob(ng,sz);
00482 }
00483
00484 if ( totp < (1.0 - 1e-5) ||
00485 totp > (1.0 + 1e-5))
00486 cout << ng << " " << pr << " [t="<< totp << "] ***\n";
00487
00488 }
00489
00490 if ((++c % 100000)==0) cerr << ".";
00491
00492 }
00493
00494
00495
00496
00497 if (oov && dict->dub()>obswrd())
00498
00499 lp+=oov * log((dict->dub() - obswrd()));
00500
00501 cout << "n=" << (int) n << " LP="
00502 << (double) lp
00503 << " PP=" << exp(lp/n)
00504 << " OVVRate=" << (oov)/n
00505
00506
00507
00508 << "\n";
00509
00510 cout.flush();
00511
00512 }
00513
00514
00515
00516
00517
00518
00519
00520
00521
00522
00523
00524
00525
00526
00527
00528
00529