00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021 using namespace std;
00022
00023 #include <cmath>
00024 #include "mfstream.h"
00025 #include "mempool.h"
00026 #include "dictionary.h"
00027 #include "n_gram.h"
00028 #include "ngramtable.h"
00029 #include "interplm.h"
00030 #include "normcache.h"
00031 #include "ngramcache.h"
00032 #include "mdiadapt.h"
00033 #include "shiftlm.h"
00034 #include "linearlm.h"
00035 #include "mixture.h"
00036 #include "cmd.h"
00037 #include "util.h"
00038
00039
00040
00041
00042
00043 static Enum_T SLmTypeEnum [] = {
00044 { (char*)"ModifiedShiftBeta", MOD_SHIFT_BETA },
00045 { (char*)"msb", MOD_SHIFT_BETA },
00046 { (char*)"InterpShiftBeta", SHIFT_BETA },
00047 { (char*)"sb", SHIFT_BETA },
00048 { (char*)"InterpShiftOne", SHIFT_ONE },
00049 { (char*)"s1", SHIFT_ONE },
00050 { (char*)"InterpShiftZero", SHIFT_ZERO },
00051 { (char*)"s0", SHIFT_ZERO },
00052 { (char*)"LinearWittenBell", LINEAR_WB },
00053 { (char*)"wb", LINEAR_WB },
00054 { (char*)"Mixture", MIXTURE },
00055 { (char*)"mix", MIXTURE },
00056 END_ENUM
00057 };
00058
00059
00060 mixture::mixture(bool fulltable,char* sublminfo,int depth,int prunefreq,char* ipfile,char* opfile):
00061 mdiadaptlm((char *)NULL,depth)
00062 {
00063
00064 prunethresh=prunefreq;
00065 ipfname=ipfile;
00066 opfname=opfile;
00067 usefulltable=fulltable;
00068
00069 mfstream inp(sublminfo,ios::in );
00070 if (!inp) {
00071 cerr << "cannot open " << sublminfo << "\n";
00072 exit(1);
00073 }
00074
00075 char line[MAX_LINE];
00076 inp.getline(line,MAX_LINE);
00077
00078 sscanf(line,"%d",&numslm);
00079
00080 sublm=new interplm* [numslm];
00081
00082 cerr << "WARNING: Parameters PruneSingletons (ps) and PruneTopSingletons (pts) are not taken into account for this type of LM (mixture); please specify the singleton pruning policy for each submodel using parameters \"-sps\" and \"-spts\" in the configuraton file\n";
00083
00084 int max_npar=6;
00085 for (int i=0; i<numslm; i++) {
00086 char **par=new char*[max_npar];
00087 par[0]=new char[BUFSIZ];
00088 par[0][0]='\0';
00089
00090 inp.getline(line,MAX_LINE);
00091
00092 const char *const wordSeparators = " \t\r\n";
00093 char *word = strtok(line, wordSeparators);
00094 int j = 1;
00095
00096 while (word){
00097 if (i>max_npar){
00098 std::cerr << "Too many parameters (expected " << max_npar << ")" << std::endl;
00099 exit(1);
00100 }
00101 par[j] = new char[MAX_LINE];
00102 strcpy(par[j],word);
00103
00104 word = strtok(0, wordSeparators);
00105 j++;
00106 }
00107
00108 int actual_npar = j;
00109
00110 char *subtrainfile;
00111 int slmtype;
00112 bool subprunesingletons;
00113 bool subprunetopsingletons;
00114 int subprunefreq;
00115
00116 DeclareParams((char*)
00117 "SubLanguageModelType",CMDENUMTYPE|CMDMSG, &slmtype, SLmTypeEnum, "type of the sub LM",
00118 "slm",CMDENUMTYPE|CMDMSG, &slmtype, SLmTypeEnum, "type of the sub LM",
00119 "sTrainOn",CMDSTRINGTYPE|CMDMSG, &subtrainfile, "training file of the sub LM",
00120 "str",CMDSTRINGTYPE|CMDMSG, &subtrainfile, "training file of the sub LM",
00121 "sPruneThresh",CMDSUBRANGETYPE|CMDMSG, &subprunefreq, 0, 1000, "threshold for pruning the sub LM",
00122 "sp",CMDSUBRANGETYPE|CMDMSG, &subprunefreq, 0, 1000, "threshold for pruning the sub LM",
00123 "sPruneSingletons",CMDBOOLTYPE|CMDMSG, &subprunesingletons, "boolean flag for pruning of singletons of the sub LM (default is true)",
00124 "sps",CMDBOOLTYPE|CMDMSG, &subprunesingletons, "boolean flag for pruning of singletons of the sub LM (default is true)",
00125 "sPruneTopSingletons",CMDBOOLTYPE|CMDMSG, &subprunetopsingletons, "boolean flag for pruning of singletons at the top level of the sub LM (default is false)",
00126 "spts",CMDBOOLTYPE|CMDMSG, &subprunetopsingletons, "boolean flag for pruning of singletons at the top level of the sub LM (default is false)",
00127 (char *)NULL );
00128
00129 subtrainfile=NULL;
00130 slmtype=0;
00131 subprunefreq=-1;
00132 subprunesingletons=true;
00133 subprunetopsingletons=false;
00134
00135 GetParams(&actual_npar, &par, (char*) NULL);
00136
00137
00138 if (!slmtype) {
00139 std::cerr << "The type (-slm) for sub LM number " << i+1 << " is not specified" << std::endl;
00140 exit(1);
00141 }
00142
00143 if (!subtrainfile) {
00144 std::cerr << "The file (-str) for sub lm number " << i+1 << " is not specified" << std::endl;
00145 exit(1);
00146 }
00147
00148 if (subprunefreq==-1) {
00149 std::cerr << "The prune threshold (-sp) for sub lm number " << i+1 << " is not specified" << std::endl;
00150 exit(1);
00151 }
00152
00153 switch (slmtype) {
00154
00155 case LINEAR_WB:
00156 sublm[i]=new linearwb(subtrainfile,depth,subprunefreq,MSHIFTBETA_I);
00157 break;
00158
00159 case SHIFT_BETA:
00160 sublm[i]=new shiftbeta(subtrainfile,depth,subprunefreq,-1,SHIFTBETA_I);
00161 break;
00162
00163 case SHIFT_ONE:
00164 sublm[i]=new shiftbeta(subtrainfile,depth,subprunefreq,SIMPLE_I);
00165 break;
00166
00167 case MOD_SHIFT_BETA:
00168 sublm[i]=new mshiftbeta(subtrainfile,depth,subprunefreq,MSHIFTBETA_I);
00169 break;
00170
00171 case MIXTURE:
00172 sublm[i]=new mixture(usefulltable,subtrainfile,depth,subprunefreq);
00173 break;
00174
00175 default:
00176 cerr << "not implemented yet\n";
00177 exit(1);
00178 };
00179
00180 sublm[i]->prunesingletons(subprunesingletons==true);
00181 sublm[i]->prunetopsingletons(subprunetopsingletons==true);
00182
00183 if (subprunetopsingletons==true)
00184
00185 sublm[i]->prunesingletons(false);
00186
00187
00188 cerr << "eventually generate OOV code of sub lm[" << i << "]\n";
00189 sublm[i]->dict->genoovcode();
00190
00191
00192 dict->augment(sublm[i]->dict);
00193
00194
00195 if(usefulltable) augment(sublm[i]);
00196
00197 }
00198
00199 cerr << "eventually generate OOV code of the mixture\n";
00200 dict->genoovcode();
00201 cerr << "dict size of the mixture:" << dict->size() << "\n";
00202
00203 k1=2;
00204 k2=10;
00205 };
00206
00207 double mixture::reldist(double *l1,double *l2,int n)
00208 {
00209 double dist=0.0,size=0.0;
00210 for (int i=0; i<n; i++) {
00211 dist+=(l1[i]-l2[i])*(l1[i]-l2[i]);
00212 size+=l1[i]*l1[i];
00213 }
00214 return sqrt(dist/size);
00215 }
00216
00217
// Returns the next value of the C library generator rand() scaled by
// RAND_MAX, i.e. a double in the closed interval [0,1].
double rand01()
{
  const double draw=static_cast<double>(rand());
  return draw/static_cast<double>(RAND_MAX);
}
00222
00223 int mixture::genpmap()
00224 {
00225 dictionary* d=sublm[0]->dict;
00226
00227 cerr << "Computing parameters mapping: ..." << d->size() << " ";
00228 pm=new int[d->size()];
00229
00230 for (int i=0; i<d->size(); i++) pm[i]=0;
00231
00232 pmax=k2-k1+1;
00233
00234 for (int w=0; w<d->size(); w++) {
00235 int f=d->freq(w);
00236 if ((f>k1) && (f<=k2)) pm[w]=f-k1;
00237 else if (f>k2) {
00238 pm[w]=pmax++;
00239 }
00240 }
00241 cerr << "pmax " << pmax << " ";
00242 return 1;
00243 }
00244
00245 int mixture::pmap(ngram ng,int lev)
00246 {
00247
00248 ngram h(sublm[0]->dict);
00249 h.trans(ng);
00250
00251 if (lev<=1) return 0;
00252
00253 if (!sublm[0]->get(h,2,1)) return 0;
00254 return (int) pm[*h.wordp(2)];
00255 }
00256
00257
00258 int mixture::savepar(char* opf)
00259 {
00260 mfstream out(opf,ios::out);
00261
00262 cerr << "saving parameters in " << opf << "\n";
00263 out << lmsize() << " " << pmax << "\n";
00264
00265 for (int i=0; i<=lmsize(); i++)
00266 for (int j=0; j<pmax; j++)
00267 out.writex(l[i][j],sizeof(double),numslm);
00268
00269
00270 return 1;
00271 }
00272
00273
00274 int mixture::loadpar(char* ipf)
00275 {
00276
00277 mfstream inp(ipf,ios::in);
00278
00279 if (!inp) {
00280 cerr << "cannot open file with parameters: " << ipf << "\n";
00281 exit(1);
00282 }
00283
00284 cerr << "loading parameters from " << ipf << "\n";
00285
00286
00287 char header[100];
00288 inp.getline(header,100);
00289 int value1,value2;
00290 sscanf(header,"%d %d",&value1,&value2);
00291
00292 if (value1 != lmsize() || value2 != pmax) {
00293 cerr << "parameter file " << ipf << " is incompatible\n";
00294 exit(1);
00295 }
00296
00297 for (int i=0; i<=lmsize(); i++)
00298 for (int j=0; j<pmax; j++)
00299 inp.readx(l[i][j],sizeof(double),numslm);
00300
00301 return 1;
00302 }
00303
00304 int mixture::train()
00305 {
00306
00307 double zf;
00308
00309 srand(1333);
00310
00311 genpmap();
00312
00313 if (dub()<dict->size()) {
00314 cerr << "\nERROR: DUB value is too small: the LM will possibly compute wrong probabilities if sub-LMs have different vocabularies!\n";
00315 cerr << "This exception should already have been handled before!!!\n";
00316 exit(1);
00317 }
00318
00319 cerr << "mixlm --> DUB: " << dub() << endl;
00320 for (int i=0; i<numslm; i++) {
00321 cerr << i << " sublm --> DUB: " << sublm[i]->dub() << endl;
00322 cerr << "eventually generate OOV code ";
00323 cerr << sublm[i]->dict->encode(sublm[i]->dict->OOV()) << "\n";
00324 sublm[i]->train();
00325 }
00326
00327
00328
00329 for (int i=0; i<=lmsize(); i++) {
00330 l[i]=new double*[pmax];
00331 for (int j=0; j<pmax; j++) {
00332 l[i][j]=new double[numslm];
00333 for (int k=0; k<numslm; k++)
00334 l[i][j][k]=1.0/(double)numslm;
00335 }
00336 }
00337
00338 if (ipfname) {
00339
00340 loadpar(ipfname);
00341 } else {
00342
00343
00344 double oldl[pmax][numslm];
00345 char alive[pmax],used[pmax];
00346 int totalive;
00347
00348 ngram ng(sublm[0]->dict);
00349
00350 for (int lev=1; lev<=lmsize(); lev++) {
00351
00352 zf=sublm[0]->zerofreq(lev);
00353
00354 cerr << "Starting training at lev:" << lev << "\n";
00355
00356 for (int i=0; i<pmax; i++) {
00357 alive[i]=1;
00358 used[i]=0;
00359 }
00360 totalive=1;
00361 int iter=0;
00362 while (totalive && (iter < 20) ) {
00363
00364 iter++;
00365
00366 for (int i=0; i<pmax; i++)
00367 if (alive[i])
00368 for (int j=0; j<numslm; j++) {
00369 oldl[i][j]=l[lev][i][j];
00370 l[lev][i][j]=1.0/(double)numslm;
00371 }
00372
00373 sublm[0]->scan(ng,INIT,lev);
00374 while(sublm[0]->scan(ng,CONT,lev)) {
00375
00376
00377 if ((lev==1) && (*ng.wordp(1)==sublm[0]->dict->oovcode()))
00378 continue;
00379
00380 int par=pmap(ng,lev);
00381 used[par]=1;
00382
00383
00384 if (alive[par]) {
00385
00386 double backoff=(lev>1?prob(ng,lev-1):1);
00387 double denom=0.0;
00388 double* numer = new double[numslm];
00389 double fstar,lambda;
00390
00391
00392
00393 int cv=(int)floor(zf * (double)ng.freq)+1;
00394
00395
00396
00397
00398
00399
00400
00401
00402
00403 for (int i=0; i<numslm; i++) {
00404
00405
00406
00407 sublm[i]->discount(ng,lev,fstar,lambda,(i==0)*(cv));
00408 numer[i]=oldl[par][i]*(fstar + lambda * backoff);
00409
00410 ngram ngslm(sublm[i]->dict);
00411 ngslm.trans(ng);
00412 if ((*ngslm.wordp(1)==sublm[i]->dict->oovcode()) &&
00413 (dict->dub() > sublm[i]->dict->size()))
00414 numer[i]/=(double)(dict->dub() - sublm[i]->dict->size());
00415
00416 denom+=numer[i];
00417 }
00418
00419 for (int i=0; i<numslm; i++) {
00420 l[lev][par][i]+=(ng.freq * (numer[i]/denom));
00421
00422
00423 }
00424 delete []numer;
00425 }
00426 }
00427
00428
00429 totalive=0;
00430 for (int i=0; i<pmax; i++) {
00431 double tot=0;
00432 if (alive[i]) {
00433 for (int j=0; j<numslm; j++) tot+=(l[lev][i][j]);
00434 for (int j=0; j<numslm; j++) l[lev][i][j]/=tot;
00435
00436
00437 if (!used[i] || (reldist(l[lev][i],oldl[i],numslm)<=0.05))
00438 alive[i]=0;
00439 }
00440 totalive+=alive[i];
00441 }
00442
00443 cerr << "Lev " << lev << " iter " << iter << " tot alive " << totalive << "\n";
00444
00445 }
00446 }
00447 }
00448
00449 if (opfname) savepar(opfname);
00450
00451
00452 return 1;
00453 }
00454
00455 int mixture::discount(ngram ng_,int size,double& fstar,double& lambda,int )
00456 {
00457
00458 ngram ng(dict);
00459 ng.trans(ng_);
00460
00461 double lambda2,fstar2;
00462 fstar=0.0;
00463 lambda=0.0;
00464 int p=pmap(ng,size);
00465 assert(p <= pmax);
00466 double lsum=0;
00467
00468
00469 for (int i=0; i<numslm; i++) {
00470 sublm[i]->discount(ng,size,fstar2,lambda2,0);
00471
00472 ngram ngslm(sublm[i]->dict);
00473 ngslm.trans(ng);
00474
00475 if (dict->dub() > sublm[i]->dict->size()){
00476 if (*ngslm.wordp(1) == sublm[i]->dict->oovcode()) {
00477 fstar2/=(double)(sublm[i]->dict->dub() - sublm[i]->dict->size()+1);
00478 }
00479 }
00480
00481
00482 fstar+=(l[size][p][i]*fstar2);
00483 lambda+=(l[size][p][i]*lambda2);
00484 lsum+=l[size][p][i];
00485 }
00486
00487 if (dict->dub() > dict->size())
00488 if (*ng.wordp(1) == dict->oovcode()) {
00489 fstar*=(double)(dict->dub() - dict->size()+1);
00490 }
00491
00492 assert((lsum>LOWER_DOUBLE_PRECISION_OF_1) && (lsum<=UPPER_DOUBLE_PRECISION_OF_1));
00493 return 1;
00494 }
00495
00496
00497
00498 int mixture::get(ngram& ng,int n,int lev)
00499 {
00500
00501 if (usefulltable)
00502 {
00503 return ngramtable::get(ng,n,lev);
00504 }
00505
00506
00507 resetngramtable();
00508
00509
00510 ngram ug(dict,1);
00511 *ug.wordp(1)=*ng.wordp(ng.size);
00512
00513
00514 ngram locng(dict,maxlevel());
00515
00516
00517 for (int i=0; i<numslm; i++) {
00518
00519 ngram subug(sublm[i]->dict,1);
00520 subug.trans(ug);
00521
00522 if (sublm[i]->get(subug,1,1)) {
00523
00524 ngram subng(sublm[i]->dict,maxlevel());
00525 *subng.wordp(maxlevel())=*subug.wordp(1);
00526 sublm[i]->scan(subug.link,subug.info,1,subng,INIT,maxlevel());
00527 while(sublm[i]->scan(subug.link,subug.info,1,subng,CONT,maxlevel())) {
00528 locng.trans(subng);
00529 put(locng);
00530 }
00531 }
00532 }
00533
00534 return ngramtable::get(ng,n,lev);
00535
00536 }
00537
00538
00539
00540
00541
00542
00543
00544