00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021 using namespace std;
00022
00023 #include <cmath>
00024 #include "mfstream.h"
00025 #include "mempool.h"
00026 #include "htable.h"
00027 #include "n_gram.h"
00028 #include "util.h"
00029 #include "dictionary.h"
00030 #include "ngramtable.h"
00031 #include "doc.h"
00032 #include "cplsa.h"
00033
// Expands to a pseudo-random double in [-1.0, 1.0]: rand() is scaled to
// [0, 1] by RAND_MAX, stretched to [0, 2] and shifted down by one.
#define MY_RAND ((static_cast<double>(rand()) / RAND_MAX) * 2.0 - 1.0)
00035
00036 plsa::plsa(dictionary* dictfile,int top,
00037 char* baseFile,char* featFile,char* hFile,char* wFile,char* tFile)
00038 {
00039
00040 dict = dictfile;
00041
00042 topics=top;
00043
00044 assert (topics>0);
00045
00046 W=new double* [dict->size()+1];
00047 for (int i=0; i<(dict->size()+1); i++) W[i]=new double [topics];
00048
00049 T=new double* [dict->size()+1];
00050 for (int i=0; i<(dict->size()+1); i++) T[i]=new double [topics];
00051
00052 H=new double [topics];
00053
00054 basefname=baseFile;
00055 featfname=featFile;
00056
00057 tfname=tFile;
00058 wfname=wFile;
00059 hinfname=new char[BUFSIZ];
00060 sprintf(hinfname,"%s",hFile);
00061
00062 houtfname=new char[BUFSIZ];
00063 sprintf(houtfname,"%s.out",hinfname);
00064 cerr << "Hfile in:" << hinfname << " out:" << houtfname << "\n";
00065 }
00066
00067 int plsa::initW(double noise,int spectopic)
00068 {
00069
00070 FILE *f;
00071
00072 if (wfname && ((f=fopen(wfname,"r"))!=NULL)) {
00073 fclose(f);
00074 loadW(wfname);
00075 } else {
00076
00077 if (spectopic) {
00078
00079 double TotW=0;
00080 for (int i=0; i<spectopic; i++)
00081 TotW+=W[i][0]=dict->freq(i);
00082 for (int i=0; i<(dict->size()+1); i++)
00083 W[i][0]/=TotW;
00084 }
00085
00086 for (int t=(spectopic?1:0); t<topics; t++) {
00087 double TotW=0;
00088 for (int i=0; i<(dict->size()+1); i++)
00089 TotW+=W[i][t]=1 + noise * MY_RAND;
00090 for (int i=0; i<(dict->size()+1); i++)
00091 W[i][t]/=TotW;
00092 }
00093 }
00094 return 1;
00095 }
00096
00097 int plsa::initH(double noise,int n)
00098 {
00099
00100 FILE *f;
00101
00102 if ((f=fopen(hinfname,"r"))==NULL) {
00103 mfstream hinfd(hinfname,ios::out);
00104 for (int j=0; j<n; j++) {
00105 double TotH=0;
00106 for (int t=0; t<topics; t++) TotH+=H[t]=1+noise * MY_RAND;
00107 for (int t=0; t<topics; t++) H[t]/=TotH;
00108 hinfd.write((const char*)H,topics *sizeof(double));
00109 }
00110 hinfd.close();
00111 } else
00112 fclose(f);
00113 return 1;
00114 }
00115
00116 int plsa::saveWtxt(char* fname)
00117 {
00118
00119 mfstream out(fname,ios::out);
00120 out << topics << "\n";
00121 for (int i=0; i<dict->size(); i++) {
00122 out << dict->decode(i) << " " << dict->freq(i);
00123 double totW=0;
00124 for (int t=0; t<topics; t++) totW+=W[i][t];
00125 out <<"totPr:" << totW << ":";
00126 for (int t=0; t<topics; t++)
00127 out << " " << W[i][t];
00128 out << "\n";
00129 }
00130 out.close();
00131 return 1;
00132 }
00133
00134 int plsa::saveW(char* fname)
00135 {
00136
00137 mfstream out(fname,ios::out);
00138 out.write((const char*)&topics,sizeof(int));
00139 for (int i=0; i<dict->size(); i++)
00140 out.write((const char*)W[i],sizeof(double)*topics);
00141 out.close();
00142 return 1;
00143 }
00144
00145 int plsa::saveT(char* fname)
00146 {
00147 mfstream out(fname,ios::out);
00148 out.write((const char*)&topics,sizeof(int));
00149 for (int i=0; i<dict->size(); i++) {
00150 double totT=0.0;
00151 for (int t=0; t<topics; t++) totT+=T[i][t];
00152 if (totT>0.00001) {
00153 out.write((const char*)&i,sizeof(int));
00154 out.write((const char*)T[i],sizeof(double)*topics);
00155 }
00156 }
00157 out.close();
00158 return 1;
00159 }
00160
00161
00162 int plsa::combineT(char* tlist)
00163 {
00164
00165 double *tvec=new double[topics];
00166 int w;
00167 int to;
00168 char fname[1000];
00169 for (int i=0; i<dict->size(); i++)
00170 for (int t=0; t<topics; t++) T[i][t]=0;
00171
00172 mfstream inp(tlist,ios::in);
00173 while (inp >> fname) {
00174 mfstream tin(fname,ios::in);
00175 tin.read((char *)&to,sizeof(int));
00176 assert(to==topics);
00177 while(!tin.eof()) {
00178 tin.read((char *)&w,sizeof(int));
00179 tin.read((char *)tvec,sizeof(double)*topics);
00180 for (int t=0; t<topics; t++) T[w][t]+=tvec[t];
00181 }
00182 tin.close();
00183 }
00184
00185 delete [] tvec;
00186
00187 for (int t=0; t<topics; t++) {
00188 double Tsum=0;
00189 for (int i=0; i<dict->size(); i++) {
00190 if (T[i][t]==0.0) T[i][t]=1e-10;
00191 Tsum+=T[i][t];
00192 }
00193 for (int i=0; i<dict->size(); i++) W[i][t]=T[i][t]/Tsum;
00194 }
00195
00196 return 1;
00197 }
00198
00199 int plsa::loadW(char* fname)
00200 {
00201 int r;
00202 mfstream inp(fname,ios::in);
00203 inp.read((char *)&r,sizeof(int));
00204
00205 if (topics>0 && r != topics) {
00206 cerr << "incompatible number of topics: " << r << "\n";
00207 exit(2);
00208 } else
00209 topics=r;
00210
00211 for (int i=0; i<dict->size(); i++)
00212 inp.read((char *)W[i],sizeof(double)*topics);
00213
00214 return 1;
00215 }
00216
00217 int plsa::saveFeat(char* fname)
00218 {
00219
00220
00221 double *WH=new double [dict->size()];
00222 for (int i=0; i<dict->size(); i++) {
00223 WH[i]=0;
00224 for (int t=0; t<topics; t++)
00225 WH[i]+=W[i][t]*H[t];
00226 }
00227
00228 double maxp=WH[0];
00229 for (int i=1; i<dict->size(); i++)
00230 if (WH[i]>maxp) maxp=WH[i];
00231
00232 cerr << "Get max prob" << maxp << "\n";
00233
00234 mfstream out(fname,ios::out);
00235 ngramtable ngt(NULL,1,NULL,NULL,NULL,0,0,NULL,0,COUNT);
00236 ngt.dict->incflag(1);
00237
00238 ngram ng(dict,1);
00239 ngram ng2(ngt.dict,1);
00240
00241 for (int i=0; i<dict->size(); i++) {
00242 *ng.wordp(1)=i;
00243 ng.freq=(int)floor((WH[i]/maxp) * 1000000);
00244 if (ng.freq) {
00245 ng2.trans(ng);
00246 ng2.freq=ng.freq;
00247
00248 ngt.put(ng2);
00249 ngt.dict->incfreq(*ng2.wordp(1),ng2.freq);
00250 }
00251 }
00252
00253 ngt.dict->incflag(0);
00254 ngt.savetxt(fname,1,1);
00255
00256 return 1;
00257 }
00258
00259
00260 int plsa::train(char *trainfile,int maxiter,double noiseH,int flagW,double noiseW,int spectopic)
00261 {
00262
00263 int dsize=dict->size();
00264
00265 srand(100);
00266
00267 if (flagW) {
00268
00269 initW(noiseW,spectopic);
00270 }
00271
00272 doc trset(dict,trainfile);
00273 trset.open();
00274
00275 initH(noiseH,trset.n);
00276
00277
00278 double *WH=new double [dsize];
00279
00280
00281 char cmd[100];
00282 sprintf(cmd,"mv %s %s",houtfname,hinfname);
00283
00284
00285
00286 double lastLL=10;
00287 double LL=-1e+99;
00288
00289 int iter=0;
00290 int r=topics;
00291
00292
00293 while (iter < maxiter)
00294
00295 {
00296 lastLL=LL;
00297 LL=0;
00298
00299 if (flagW)
00300 for (int i=0; i<dict->size(); i++)
00301 for (int t=0; t<r; t++)
00302 T[i][t]=0;
00303
00304 {
00305
00306 mfstream hindf(hinfname,ios::in);
00307 mfstream houtdf(houtfname,ios::out);
00308
00309 while(trset.read()) {
00310
00311 int m=trset.m;
00312
00313 int j=trset.cd;
00314 int N=0;
00315
00316
00317 hindf.read((char *)H,topics * sizeof(double));
00318
00319
00320 for (int i=0; i<m; i++) {
00321 WH[trset.V[i]]=0;
00322 N+=trset.N[trset.V[i]];
00323 for (int t=0; t<r; t++)
00324 WH[trset.V[i]]+=W[trset.V[i]][t]*H[t];
00325 LL+=trset.N[trset.V[i]] * log( WH[trset.V[i]] );
00326 }
00327
00328
00329 if (flagW) {
00330 for (int i=0; i<m; i++) {
00331 for (int t=0; t<r; t++)
00332 T[trset.V[i]][t]+=
00333 (trset.N[trset.V[i]] * W[trset.V[i]][t] *
00334 H[t]/WH[trset.V[i]]);
00335 }
00336 }
00337
00338
00339 double totH=0;
00340 for (int t=0; t<r; t++) {
00341 double tmpHaj=0;
00342 for (int i=0; i<m; i++)
00343 tmpHaj+=(trset.N[trset.V[i]] * W[trset.V[i]][t] *
00344 H[t]/WH[trset.V[i]]);
00345 H[t]=tmpHaj/(double)N;
00346 totH+=H[t];
00347 }
00348
00349 if(totH>UPPER_SINGLE_PRECISION_OF_1 || totH<LOWER_SINGLE_PRECISION_OF_1) {
00350 cerr << "totH=" << totH << "\n";
00351 exit(1);
00352 }
00353
00354
00355 houtdf.write((const char*)H,topics * sizeof(double));
00356
00357
00358 if (!(j % 10000)) cerr << ".";
00359
00360 }
00361
00362 hindf.close();
00363 houtdf.close();
00364
00365 cerr << cmd <<"\n";
00366 system(cmd);
00367 }
00368
00369
00370 if (flagW) {
00371 cerr <<"end of train file final update of Wia\n";
00372 for (int t=0; t<r; t++) {
00373 double Tsum=0;
00374 for (int i=0; i<dsize; i++) Tsum+=T[i][t];
00375 for (int i=0; i<dsize; i++) W[i][t]=T[i][t]/Tsum;
00376 cerr << "end of normalization\n";
00377 }
00378 }
00379 trset.reset();
00380
00381 cout << "iteration: " << ++iter << " LL: " << LL << "\n";
00382
00383
00384 if (flagW) {
00385 cerr << "Saving base distributions\n";
00386 if (tfname) saveT(tfname);
00387 else saveW(basefname);
00388 }
00389
00390 }
00391
00392 if (!flagW) {
00393 cout << "Saving features\n";
00394 saveFeat(featfname);
00395 }
00396
00397 delete [] WH;
00398 return 1;
00399 }
00400
00401
00402
00403
00404
00405
00406