00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021 #include <string.h>
00022 #include <stdio.h>
00023 #include <stdlib.h>
00024 #include <assert.h>
00025 #include <math.h>
00026 #include "mfstream.h"
00027 #include "mempool.h"
00028 #include "htable.h"
00029 #include "dictionary.h"
00030 #include "n_gram.h"
00031 #include "ngramtable.h"
00032 #include "ngramcache.h"
00033 #include "normcache.h"
00034 #include "interplm.h"
00035 #include "mdiadapt.h"
00036 #include "shiftlm.h"
00037
00038
00039
00040
00041
00042 shiftone::shiftone(char* ngtfile,int depth,int prunefreq,TABLETYPE tt):
00043 mdiadaptlm(ngtfile,depth,tt)
00044 {
00045 cerr << "Creating LM with ShiftOne smoothing\n";
00046 prunethresh=prunefreq;
00047 cerr << "PruneThresh: " << prunethresh << "\n";
00048
00049 beta=1.0;
00050
00051 };
00052
00053
00054 int shiftone::train()
00055 {
00056 trainunigr();
00057 return 1;
00058 }
00059
00060
00061 int shiftone::discount(ngram ng_,int size,double& fstar,double& lambda, int cv)
00062 {
00063
00064 ngram ng(dict);
00065 ng.trans(ng_);
00066
00067
00068
00069 if (size > 1) {
00070
00071 ngram history=ng;
00072
00073 if (ng.ckhisto(size) && get(history,size,size-1) && (history.freq>cv) &&
00074 ((size < 3) || ((history.freq-cv) > prunethresh))) {
00075
00076
00077
00078 get(ng,size,size);
00079 cv=(cv>ng.freq)?ng.freq:cv;
00080
00081 if (ng.freq > cv) {
00082
00083 fstar=(double)((double)(ng.freq - cv) - beta)/(double)(history.freq-cv);
00084
00085 lambda=beta * ((double)history.succ/(double)(history.freq-cv));
00086
00087 } else {
00088
00089 fstar=0.0;
00090
00091 lambda=beta * ((double)(history.succ-1)/
00092 (double)(history.freq-cv));
00093
00094 }
00095
00096
00097
00098
00099 if (*ng.wordp(1)==dict->oovcode()) {
00100 lambda+=fstar;
00101 fstar=0.0;
00102 } else {
00103 *ng.wordp(1)=dict->oovcode();
00104 if (get(ng,size,size))
00105 lambda+=(double)((double)ng.freq - beta)/(double)(history.freq-cv);
00106 }
00107
00108 } else {
00109 fstar=0;
00110 lambda=1;
00111 }
00112 } else {
00113 fstar=unigr(ng);
00114 lambda=0.0;
00115 }
00116
00117 return 1;
00118 }
00119
00120
00121
00122
00123
00124
00125
00126
00127 shiftbeta::shiftbeta(char* ngtfile,int depth,int prunefreq,double b,TABLETYPE tt):
00128 mdiadaptlm(ngtfile,depth,tt)
00129 {
00130 cerr << "Creating LM with ShiftBeta smoothing\n";
00131
00132 if (b==-1.0 || (b < 1.0 && b >0.0)) {
00133 beta=new double[lmsize()+1];
00134 for (int l=lmsize(); l>1; l--)
00135 beta[l]=b;
00136 } else {
00137 cerr << "shiftbeta: beta must be < 1.0 and > 0\n";
00138 exit (1);
00139 }
00140
00141 prunethresh=prunefreq;
00142 cerr << "PruneThresh: " << prunethresh << "\n";
00143 };
00144
00145
00146
00147 int shiftbeta::train()
00148 {
00149 ngram ng(dict);
00150 int n1,n2;
00151
00152 trainunigr();
00153
00154 beta[1]=0.0;
00155
00156 for (int l=2; l<=lmsize(); l++) {
00157
00158 cerr << "level " << l << "\n";
00159 n1=0;
00160 n2=0;
00161 scan(ng,INIT,l);
00162 while(scan(ng,CONT,l)) {
00163
00164
00165 if (l<lmsize()) {
00166
00167
00168
00169 ngram hg=ng;
00170 get(hg,l,l);
00171 int s1=0;
00172 ngram ng2=hg;
00173 ng2.pushc(0);
00174
00175 succscan(hg,ng2,INIT,l+1);
00176 while(succscan(hg,ng2,CONT,l+1)) {
00177 if (ng2.freq==1) s1++;
00178 }
00179 succ1(hg.link,s1);
00180 }
00181
00182
00183 if (l>1 && ng.containsWord(dict->OOV(),l)) {
00184
00185 continue;
00186 }
00187
00188
00189 if (l>1 && ng.containsWord(dict->EoS(),l-1)) {
00190
00191 continue;
00192 }
00193
00194
00195 if (l==1 && ng.containsWord(dict->BoS(),l)) {
00196
00197 continue;
00198 }
00199
00200 if (ng.freq==1) n1++;
00201 else if (ng.freq==2) n2++;
00202
00203 }
00204
00205 if (beta[l]==-1) {
00206 if (n1>0)
00207 beta[l]=(double)n1/(double)(n1 + 2 * n2);
00208 else {
00209 cerr << "no singletons! \n";
00210 beta[l]=1.0;
00211 }
00212 }
00213 cerr << beta[l] << "\n";
00214 }
00215
00216 return 1;
00217 };
00218
00219
00220
00221 int shiftbeta::discount(ngram ng_,int size,double& fstar,double& lambda, int cv)
00222 {
00223
00224 ngram ng(dict);
00225 ng.trans(ng_);
00226
00227 if (size > 1) {
00228
00229 ngram history=ng;
00230
00231 if (ng.ckhisto(size) && get(history,size,size-1) && (history.freq>cv) &&
00232
00233 ((size < 3) || ((history.freq-cv) > prunethresh ))) {
00234
00235
00236
00237
00238 if (get(ng,size,size) && (!prunesingletons() || ng.freq >1 || size<3)) {
00239
00240 cv=(cv>ng.freq)?ng.freq:cv;
00241
00242 if (ng.freq>cv) {
00243
00244 fstar=(double)((double)(ng.freq - cv) - beta[size])/(double)(history.freq-cv);
00245
00246 lambda=beta[size]*((double)history.succ/(double)(history.freq-cv));
00247
00248 if (size>=3 && prunesingletons())
00249
00250 lambda+=(1.0-beta[size]) * (double)succ1(history.link)/(double)(history.freq-cv);
00251
00252
00253
00254 } else {
00255
00256 fstar=0.0;
00257
00258 lambda=beta[size]*((double)(history.succ-1)/
00259 (double)(history.freq-cv));
00260
00261 if (size>=3 && prunesingletons())
00262 lambda+=(1.0-beta[size]) * (double)(succ1(history.link)-(cv==1 && ng.freq==1?1:0))
00263 /(double)(history.freq-cv);
00264 }
00265 } else {
00266
00267 fstar=0.0;
00268 lambda=beta[size]*(double)history.succ/(double)history.freq;
00269
00270 if (size>=3 && prunesingletons())
00271 lambda+=(1.0-beta[size]) * (double)succ1(history.link)/(double)history.freq;
00272
00273 }
00274
00275
00276
00277 if (*ng.wordp(1)==dict->oovcode()) {
00278 lambda+=fstar;
00279 fstar=0.0;
00280 } else {
00281 *ng.wordp(1)=dict->oovcode();
00282 if (get(ng,size,size) && (!prunesingletons() || ng.freq >1 || size<3))
00283 lambda+=(double)((double)ng.freq - beta[size])/(double)(history.freq-cv);
00284 }
00285
00286 } else {
00287 fstar=0;
00288 lambda=1;
00289 }
00290 } else {
00291 fstar=unigr(ng);
00292 lambda=0.0;
00293 }
00294
00295 return 1;
00296 }
00297
00298
00299
00300
00301
00302 mshiftbeta::mshiftbeta(char* ngtfile,int depth,int prunefreq,TABLETYPE tt):
00303 mdiadaptlm(ngtfile,depth,tt)
00304 {
00305 cerr << "Creating LM with Modified ShiftBeta smoothing\n";
00306
00307 prunethresh=prunefreq;
00308 cerr << "PruneThresh: " << prunethresh << "\n";
00309
00310 beta[1][0]=0.0;
00311 beta[1][1]=0.0;
00312 beta[1][2]=0.0;
00313
00314 };
00315
00316
00317 int mshiftbeta::train()
00318 {
00319
00320 trainunigr();
00321
00322 gencorrcounts();
00323 gensuccstat();
00324
00325 ngram ng(dict);
00326 int n1,n2,n3,n4;
00327 int unover3=0;
00328
00329 oovsum=0;
00330
00331 for (int l=1; l<=lmsize(); l++) {
00332
00333 cerr << "level " << l << "\n";
00334
00335 cerr << "computing statistics\n";
00336
00337 n1=0;
00338 n2=0;
00339 n3=0,n4=0;
00340
00341 scan(ng,INIT,l);
00342
00343 while(scan(ng,CONT,l)) {
00344
00345
00346 if (l>1 && ng.containsWord(dict->OOV(),l)) {
00347
00348 continue;
00349 }
00350
00351
00352 if (l>1 && ng.containsWord(dict->EoS(),l-1)) {
00353
00354 continue;
00355 }
00356
00357
00358 if (l==1 && ng.containsWord(dict->BoS(),l)) {
00359
00360 continue;
00361 }
00362
00363 ng.freq=mfreq(ng,l);
00364
00365 if (ng.freq==1) n1++;
00366 else if (ng.freq==2) n2++;
00367 else if (ng.freq==3) n3++;
00368 else if (ng.freq==4) n4++;
00369 if (l==1 && ng.freq >=3) unover3++;
00370
00371 }
00372
00373 if (l==1) {
00374 cerr << " n1: " << n1 << " n2: " << n2 << " n3: " << n3 << " n4: " << n4 << " unover3: " << unover3 << "\n";
00375 } else {
00376 cerr << " n1: " << n1 << " n2: " << n2 << " n3: " << n3 << " n4: " << n4 << "\n";
00377 }
00378
00379 if (n1 == 0 || n2 == 0 || n1 <= n2) {
00380 cerr << "Error: lower order count-of-counts cannot be estimated properly\n";
00381 cerr << "Hint: use another smoothing method with this corpus.\n";
00382 exit(1);
00383 }
00384
00385 double Y=(double)n1/(double)(n1 + 2 * n2);
00386 beta[0][l] = Y;
00387
00388 if (n3 ==0 || n4 == 0 || n2 <= n3 || n3 <= n4 ){
00389 cerr << "Warning: higher order count-of-counts cannot be estimated properly\n";
00390 cerr << "Fixing this problem by resorting only on the lower order count-of-counts\n";
00391
00392 beta[1][l] = Y;
00393 beta[2][l] = Y;
00394 }
00395 else{
00396 beta[1][l] = 2 - 3 * Y * n3 / n2;
00397 beta[2][l] = 3 - 4 * Y * n4 / n3;
00398 }
00399
00400 if (beta[1][l] < 0){
00401 cerr << "Warning: discount coefficient is negative \n";
00402 cerr << "Fixing this problem by setting beta to 0 \n";
00403 beta[1][l] = 0;
00404
00405 }
00406
00407
00408 if (beta[2][l] < 0){
00409 cerr << "Warning: discount coefficient is negative \n";
00410 cerr << "Fixing this problem by setting beta to 0 \n";
00411 beta[2][l] = 0;
00412
00413 }
00414
00415
00416 if (l==1)
00417 oovsum=beta[0][l] * (double) n1 + beta[1][l] * (double)n2 + beta[2][l] * (double)unover3;
00418
00419 cerr << beta[0][l] << " " << beta[1][l] << " " << beta[2][l] << "\n";
00420 }
00421
00422 return 1;
00423 };
00424
00425
00426
00427 int mshiftbeta::discount(ngram ng_,int size,double& fstar,double& lambda, int cv)
00428 {
00429 ngram ng(dict);
00430 ng.trans(ng_);
00431
00432
00433
00434 if (size > 1) {
00435
00436 ngram history=ng;
00437
00438
00439 if (ng.ckhisto(size) && get(history,size,size-1) && (history.freq > cv) &&
00440 ((size < 3) || ((history.freq-cv) > prunethresh ))) {
00441
00442 int suc[3];
00443 suc[0]=succ1(history.link);
00444 suc[1]=succ2(history.link);
00445 suc[2]=history.succ-suc[0]-suc[1];
00446
00447
00448 if (get(ng,size,size) &&
00449 (!prunesingletons() || mfreq(ng,size)>1 || size<3) &&
00450 (!prunetopsingletons() || mfreq(ng,size)>1 || size<maxlevel())) {
00451
00452 ng.freq=mfreq(ng,size);
00453
00454 cv=(cv>ng.freq)?ng.freq:cv;
00455
00456 if (ng.freq>cv) {
00457
00458 double b=(ng.freq-cv>=3?beta[2][size]:beta[ng.freq-cv-1][size]);
00459
00460 fstar=(double)((double)(ng.freq - cv) - b)/(double)(history.freq-cv);
00461
00462 lambda=(beta[0][size] * suc[0] + beta[1][size] * suc[1] + beta[2][size] * suc[2])
00463 /
00464 (double)(history.freq-cv);
00465
00466 if ((size>=3 && prunesingletons()) ||
00467 (size==maxlevel() && prunetopsingletons()))
00468
00469 lambda+=(double)(suc[0] * (1-beta[0][size])) / (double)(history.freq-cv);
00470
00471 } else {
00472
00473
00474 ng.freq>=3?suc[2]--:suc[ng.freq-1]--;
00475
00476 fstar=0.0;
00477 lambda=(beta[0][size] * suc[0] + beta[1][size] * suc[1] + beta[2][size] * suc[2])
00478 /
00479 (double)(history.freq-cv);
00480
00481 if ((size>=3 && prunesingletons()) ||
00482 (size==maxlevel() && prunetopsingletons()))
00483 lambda+=(double)(suc[0] * (1-beta[0][size])) / (double)(history.freq-cv);
00484
00485 ng.freq>=3?suc[2]++:suc[ng.freq-1]++;
00486 }
00487 } else {
00488 fstar=0.0;
00489 lambda=(beta[0][size] * suc[0] + beta[1][size] * suc[1] + beta[2][size] * suc[2])
00490 /
00491 (double)(history.freq-cv);
00492
00493 if ((size>=3 && prunesingletons()) ||
00494 (size==maxlevel() && prunetopsingletons()))
00495 lambda+=(double)(suc[0] * (1-beta[0][size])) / (double)(history.freq-cv);
00496
00497 }
00498
00499
00500
00501
00502 if (*ng.wordp(1)==dict->oovcode()) {
00503 lambda+=fstar;
00504 fstar=0.0;
00505 } else {
00506 *ng.wordp(1)=dict->oovcode();
00507 if (get(ng,size,size)) {
00508 ng.freq=mfreq(ng,size);
00509 if ((!prunesingletons() || mfreq(ng,size)>1 || size<3) &&
00510 (!prunetopsingletons() || mfreq(ng,size)>1 || size<maxlevel())) {
00511 double b=(ng.freq>=3?beta[2][size]:beta[ng.freq-1][size]);
00512 lambda+=(double)(ng.freq - b)/(double)(history.freq-cv);
00513 }
00514 }
00515 }
00516 } else {
00517 fstar=0;
00518 lambda=1;
00519 }
00520 } else {
00521
00522 lambda=0.0;
00523
00524 int unigrtotfreq=(size<lmsize()?btotfreq():totfreq());
00525
00526
00527
00528 if (get(ng,size,size))
00529 fstar=(double) mfreq(ng,size)/(double)unigrtotfreq;
00530 else {
00531 cerr << "Missing probability for word: " << dict->decode(*ng.wordp(1)) << "\n";
00532 exit(1);
00533 }
00534 }
00535
00536 return 1;
00537 }
00538
00539
00540
00541 int symshiftbeta::discount(ngram ng_,int size,double& fstar,double& lambda, int )
00542 {
00543 ngram ng(dict);
00544 ng.trans(ng_);
00545
00546
00547
00548
00549
00550
00551 assert(size<=2);
00552
00553 if (size == 3) {
00554
00555 ngram history=ng;
00556
00557
00558 }
00559 if (size == 2) {
00560
00561
00562 ngram unig(dict,1);
00563 *unig.wordp(1)=*ng.wordp(2);
00564 double prunig=unigr(unig);
00565
00566
00567 if (*ng.wordp(1) > *ng.wordp(2)) {
00568 int tmp=*ng.wordp(1);
00569 *ng.wordp(1)=*ng.wordp(2);
00570 *ng.wordp(2)=tmp;
00571 }
00572
00573 lambda=beta[2] * (double) entries(2)/(double)totfreq();
00574
00575 if (get(ng,2,2)) {
00576 fstar=(double)((double)ng.freq - beta[2])/
00577 (totfreq() * prunig);
00578 } else {
00579 fstar=0;
00580 }
00581 } else {
00582 fstar=unigr(ng);
00583 lambda=0.0;
00584 }
00585 return 1;
00586 }
00587
00588
00589
00590
00591
00592
00593
00594
00595
00596
00597
00598
00599
00600
00601
00602
00603