00001 
00002 
00003 
00004 
00005 
00006 
00007 
00008 
00009 
00010 #include <queue>
00011 #include <iomanip>
00012 #include <vector>
00013 #include <iterator>
00014 #include <sstream>
00015 #include <algorithm>
00016 
00017 #include <boost/program_options.hpp>
00018 #include <boost/dynamic_bitset.hpp>
00019 #include <boost/shared_ptr.hpp>
00020 #include <boost/foreach.hpp>
00021 #include <boost/thread.hpp>
00022 #include <boost/math/distributions/binomial.hpp>
00023 #include <boost/unordered_map.hpp>
00024 #include <boost/unordered_set.hpp>
00025 
00026 #include "moses/TranslationModel/UG/generic/program_options/ug_get_options.h"
00027 #include "moses/Util.h"
00028 #include "ug_mm_2d_table.h"
00029 #include "ug_mm_ttrack.h"
00030 #include "ug_corpus_token.h"
00031 
00032 using namespace std;
00033 using namespace sapt;
00034 using namespace ugdiss;
00035 using namespace boost::math;
00036 
// Table type whose template parameters mirror what writeTable emits
// (id_type ids, uint32_t counts). Not referenced elsewhere in this file --
// presumably kept for readers of the output; TODO confirm.
typedef mm2dTable<id_type,id_type,uint32_t,uint32_t> LEX_t;
// Token type stored in the .mct corpus tracks.
typedef SimpleWordId Token;


// Parses the command line into the option globals declared below.
void interpret_args(int ac, char* av[]);

// Corpus data opened in main(): the L1 and L2 token tracks (.mct), the
// alignment track (.mam, raw bytes decoded in Counter::processSentence),
// and the L1/L2 vocabularies (.tdx).
mmTtrack<Token> T1,T2;
mmTtrack<char>     Tx;
TokenIndex      V1,V2;

// Key of the count map: (L1 word id, L2 word id). processSentence uses 0 as
// the partner id when counting unaligned words.
typedef pair<id_type,id_type> wpair;
/// Tallies kept for one (L1 word, L2 word) pair.
struct Count
{
  uint32_t a; ///< number of alignment links (incremented by Counter)
  uint32_t c; ///< co-occurrence count (summed by writeTable; not incremented in the visible code)

  /// Zero-initialized tallies.
  Count()
  {
    a = c = 0;
  }

  /// Tallies with explicit link (`ax`) and co-occurrence (`cx`) counts.
  Count(uint32_t ax, uint32_t cx)
  {
    a = ax;
    c = cx;
  }
};
00055 
00056 bool
00057 operator<(pair<id_type,Count> const& a,
00058           pair<id_type,Count> const& b)
00059 {
00060   return a.first < b.first;
00061 }
00062 
00063 
// Per-worker hash map from (L1 id, L2 id) to its tallies; filled by
// Counter::processSentence.
typedef boost::unordered_map<wpair,Count> countmap_t;
// Indexed by L1 word id: list of (L2 id, tallies) entries. Counter::operator()
// sorts each list by L2 id once counting is done.
typedef vector<vector<pair<id_type,Count> > > countlist_t;

// One countlist_t per worker thread (sized in main); merged by writeTable.
vector<countlist_t> XLEX;
00068 
// Thread functor: counts alignment links for sentences offset, offset+skip,
// offset+2*skip, ... and leaves the result, grouped by L1 word id, in LEX.
class Counter
{
public:
  countmap_t  CNT;    // (L1 id, L2 id) -> tallies accumulated by this worker
  countlist_t & LEX;  // output slot (a caller-owned element of XLEX in main)
  size_t  offset;     // first sentence id handled by this worker
  size_t    skip;     // stride between sentence ids handled by this worker
  Counter(countlist_t& lex, size_t o, size_t s)
    : LEX(lex), offset(o), skip(s) {}
  // Adds the links and unaligned words of sentence pair `sid` to CNT.
  void processSentence(id_type sid);
  // Thread body: processes this worker's sentences, then fills LEX from CNT.
  void operator()();
};
00081 
// Settings filled in by interpret_args():
//   bname   - corpus base name (main appends '.' unless it ends in '/' or '.')
//   cfgFile - optional config file
//   L1, L2  - language tags used to build the input file names
//   oname   - output file for the link-count table
//   cooc    - output file for co-occurrence counts
//             (NOTE(review): no command-line option currently sets it)
string bname,cfgFile,L1,L2,oname,cooc;
int    verbose;      // verbosity level (-v); enables progress messages
size_t truncat;      // sentence limit (-n); main replaces 0 with T1.size()
size_t num_threads;  // worker count (-t), capped in interpret_args
00086 
00087 void
00088 Counter::
00089 operator()()
00090 {
00091   for (size_t sid = offset; sid < min(truncat,T1.size()); sid += skip)
00092     processSentence(sid);
00093 
00094   LEX.resize(V1.ksize());
00095   for (countmap_t::const_iterator c = CNT.begin(); c != CNT.end(); ++c)
00096     {
00097       pair<id_type,Count> foo(c->first.second,c->second);
00098       LEX.at(c->first.first).push_back(foo);
00099     }
00100   typedef vector<pair<id_type,Count> > v_t;
00101   BOOST_FOREACH(v_t& v, LEX)
00102     sort(v.begin(),v.end());
00103 }
00104 
00105 struct lexsorter
00106 {
00107   vector<countlist_t> const& v;
00108   id_type wid;
00109   lexsorter(vector<countlist_t> const& vx, id_type widx)
00110     : v(vx),wid(widx) {}
00111   bool operator()(pair<uint32_t,uint32_t> const& a,
00112                   pair<uint32_t,uint32_t> const& b) const
00113   {
00114     return (v.at(a.first).at(wid).at(a.second).first >
00115             v.at(b.first).at(wid).at(b.second).first);
00116   }
00117 };
00118 
00119 void
00120 writeTableHeader(ostream& out)
00121 {
00122   filepos_type idxOffset=0;
00123   tpt::numwrite(out,idxOffset); 
00124   tpt::numwrite(out,id_type(V1.ksize()));
00125   tpt::numwrite(out,id_type(V2.ksize()));
00126 }
00127 
00128 void writeTable(ostream* aln_out, ostream* coc_out)
00129 {
00130   vector<uint32_t> m1a(V1.ksize(),0); 
00131   vector<uint32_t> m2a(V2.ksize(),0); 
00132   vector<uint32_t> m1c(V1.ksize(),0); 
00133   vector<uint32_t> m2c(V2.ksize(),0); 
00134   vector<id_type> idxa(V1.ksize()+1,0);
00135   vector<id_type> idxc(V1.ksize()+1,0);
00136   if (aln_out) writeTableHeader(*aln_out);
00137   if (coc_out) writeTableHeader(*coc_out);
00138   size_t CellCountA=0,CellCountC=0;
00139   for (size_t id1 = 0; id1 < V1.ksize(); ++id1)
00140     {
00141       idxa[id1] = CellCountA;
00142       idxc[id1] = CellCountC;
00143       lexsorter sorter(XLEX,id1);
00144       vector<pair<uint32_t,uint32_t> > H; H.reserve(num_threads);
00145       for (size_t i = 0; i < num_threads; ++i)
00146         {
00147           if (id1 < XLEX.at(i).size() && XLEX[i][id1].size())
00148             H.push_back(pair<uint32_t,uint32_t>(i,0));
00149         }
00150       if (!H.size()) continue;
00151       make_heap(H.begin(),H.end(),sorter);
00152       while (H.size())
00153         {
00154           id_type  id2 = XLEX[H[0].first][id1][H[0].second].first;
00155           uint32_t aln = XLEX[H[0].first][id1][H[0].second].second.a;
00156           uint32_t coc = XLEX[H[0].first][id1][H[0].second].second.c;
00157           pop_heap(H.begin(),H.end(),sorter);
00158           ++H.back().second;
00159           if (H.back().second == XLEX[H.back().first][id1].size())
00160             H.pop_back();
00161           else
00162             push_heap(H.begin(),H.end(),sorter);
00163           while (H.size() &&
00164                  XLEX[H[0].first][id1].at(H[0].second).first == id2)
00165             {
00166               aln += XLEX[H[0].first][id1][H[0].second].second.a;
00167               coc += XLEX[H[0].first][id1][H[0].second].second.c;
00168               pop_heap(H.begin(),H.end(),sorter);
00169               ++H.back().second;
00170               if (H.back().second == XLEX[H.back().first][id1].size())
00171                 H.pop_back();
00172               else
00173                 push_heap(H.begin(),H.end(),sorter);
00174             }
00175           if (aln_out)
00176             {
00177               ++CellCountA;
00178               tpt::numwrite(*aln_out,id2);
00179               tpt::numwrite(*aln_out,aln);
00180               m1a[id1] += aln;
00181               m2a[id2] += aln;
00182             }
00183           if (coc_out && coc)
00184             {
00185               ++CellCountC;
00186               tpt::numwrite(*coc_out,id2);
00187               tpt::numwrite(*coc_out,coc);
00188               m1c[id1] += coc;
00189               m2c[id2] += coc;
00190             }
00191         }
00192     }
00193   idxa.back() = CellCountA;
00194   idxc.back() = CellCountC;
00195   if (aln_out)
00196     {
00197       filepos_type idxOffsetA = aln_out->tellp();
00198       BOOST_FOREACH(id_type foo, idxa)
00199         tpt::numwrite(*aln_out,foo);
00200       aln_out->write(reinterpret_cast<char const*>(&m1a[0]),m1a.size()*4);
00201       aln_out->write(reinterpret_cast<char const*>(&m2a[0]),m2a.size()*4);
00202       aln_out->seekp(0);
00203       tpt::numwrite(*aln_out,idxOffsetA);
00204     }
00205   if (coc_out)
00206     {
00207       filepos_type idxOffsetC = coc_out->tellp();
00208       BOOST_FOREACH(id_type foo, idxc)
00209         tpt::numwrite(*coc_out,foo);
00210       coc_out->write(reinterpret_cast<char const*>(&m1c[0]),m1c.size()*4);
00211       coc_out->write(reinterpret_cast<char const*>(&m2c[0]),m2c.size()*4);
00212       coc_out->seekp(0);
00213       tpt::numwrite(*coc_out,idxOffsetC);
00214     }
00215 }
00216 
// Adds the alignment links of sentence pair `sid` to CNT. It also adds one
// count per unaligned word, paired with id 0 on the other side.
// Throws via UTIL_THROW_IF2 if a link position lies outside either sentence.
// NOTE: UTIL_THROW_IF2 may embed its condition text in the exception message,
// so the identifiers used in those conditions are deliberately left unchanged.
void
Counter::
processSentence(id_type sid)
{
  // Token ranges of the L1 (s1..e1) and L2 (s2..e2) sides of this pair.
  Token const* s1 = T1.sntStart(sid);
  Token const* e1 = T1.sntEnd(sid);
  Token const* s2 = T2.sntStart(sid);
  Token const* e2 = T2.sntEnd(sid);

  // One bit per token position, initially all set. A bit is cleared when its
  // position appears in a link, so bits still set at the end mark unaligned
  // words.
  bitvector check1(T1.sntLen(sid)); check1.set();
  bitvector check2(T2.sntLen(sid)); check2.set();

  // Raw alignment bytes for this sentence. Each link is decoded with
  // tpt::binread into r (L1 position) and c (L2 position).
  char const*   p = Tx.sntStart(sid);
  char const*   q = Tx.sntEnd(sid);
  ushort r,c;
  // Progress report every millionth sentence id.
  if (verbose && sid % 1000000 == 0)
    cerr << sid/1000000 << " M sentences processed" << endl;
  while (p < q)
    {
      p = tpt::binread(p,r);
      p = tpt::binread(p,c);
      // Reject link positions beyond the sentence lengths ...
      UTIL_THROW_IF2(r >= check1.size(), "out of bounds at line " << sid);
      UTIL_THROW_IF2(c >= check2.size(), "out of bounds at line " << sid);
      // ... and beyond the token ranges themselves.
      UTIL_THROW_IF2(s1+r >= e1, "out of bounds at line " << sid);
      UTIL_THROW_IF2(s2+c >= e2, "out of bounds at line " << sid);
      // Mark both positions as aligned and count the (L1 word, L2 word) link.
      check1.reset(r);
      check2.reset(c);
      id_type id1 = (s1+r)->id();
      id_type id2 = (s2+c)->id();
      wpair k(id1,id2);
      Count& cnt = CNT[k];
      cnt.a++;
      // NOTE(review): cnt.c (co-occurrence) is never updated here, so the
      // co-occurrence sums computed in writeTable stay 0.
    }
  // Count unaligned words. L1 words are counted against L2 id 0, and L2
  // words against L1 id 0. Presumably bitvector is a boost::dynamic_bitset,
  // whose find_first/find_next return npos (>= size()) once no set bit
  // remains -- TODO confirm.
  for (size_t i = check1.find_first();
       i < check1.size();
       i = check1.find_next(i))
    CNT[wpair((s1+i)->id(),0)].a++;
  for (size_t i = check2.find_first();
       i < check2.size();
       i = check2.find_next(i))
    CNT[wpair(0,(s2+i)->id())].a++;
}
00275 
00276 int
00277 main(int argc, char* argv[])
00278 {
00279   interpret_args(argc,argv);
00280   char c = *bname.rbegin();
00281   if (c != '/' && c != '.') bname += '.';
00282   T1.open(bname+L1+".mct");
00283   T2.open(bname+L2+".mct");
00284   Tx.open(bname+L1+"-"+L2+".mam");
00285   V1.open(bname+L1+".tdx");
00286   V2.open(bname+L2+".tdx");
00287   if (!truncat) truncat = T1.size();
00288   XLEX.resize(num_threads);
00289   vector<boost::shared_ptr<boost::thread> > workers(num_threads);
00290   for (size_t i = 0; i < num_threads; ++i)
00291     workers[i].reset(new boost::thread(Counter(XLEX[i],i,num_threads)));
00292   for (size_t i = 0; i < workers.size(); ++i)
00293     workers[i]->join();
00294   
00295   ofstream aln_out,coc_out;
00296   if (oname.size()) aln_out.open(oname.c_str());
00297   
00298   writeTable(oname.size() ? &aln_out : NULL,
00299              cooc.size()  ? &coc_out : NULL);
00300   if (oname.size()) aln_out.close();
00301   
00302 }
00303 
00304 void
00305 interpret_args(int ac, char* av[])
00306 {
00307   namespace po=boost::program_options;
00308   po::variables_map vm;
00309   po::options_description o("Options");
00310   po::options_description h("Hidden Options");
00311   po::positional_options_description a;
00312 
00313   o.add_options()
00314     ("help,h",    "print this message")
00315     ("cfg,f", po::value<string>(&cfgFile),"config file")
00316     ("oname,o", po::value<string>(&oname),"output file name")
00317     
00318     
00319     ("verbose,v", po::value<int>(&verbose)->default_value(0)->implicit_value(1),
00320      "verbosity level")
00321     ("threads,t", po::value<size_t>(&num_threads)->default_value(4),
00322      "count in <N> parallel threads")
00323     ("truncate,n", po::value<size_t>(&truncat)->default_value(0),
00324      "truncate corpus to <N> sentences (for debugging)")
00325     ;
00326 
00327   h.add_options()
00328     ("bname", po::value<string>(&bname), "base name")
00329     ("L1",    po::value<string>(&L1),"L1 tag")
00330     ("L2",    po::value<string>(&L2),"L2 tag")
00331     ;
00332   a.add("bname",1);
00333   a.add("L1",1);
00334   a.add("L2",1);
00335   get_options(ac,av,h.add(o),a,vm,"cfg");
00336 
00337   if (vm.count("help") || bname.empty() || (oname.empty() && cooc.empty()))
00338     {
00339       cout << "usage:\n\t" << av[0] << " <basename> <L1 tag> <L2 tag> [-o <output file>] [-c <output file>]\n" << endl;
00340       cout << "at least one of -o / -c must be specified." << endl;
00341       cout << o << endl;
00342       exit(0);
00343     }
00344   size_t num_cores = boost::thread::hardware_concurrency();
00345   num_threads = min(num_threads,num_cores);
00346 }
00347 
00348