00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022 #ifndef moses_PhraseTableCreator_h
00023 #define moses_PhraseTableCreator_h
00024
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <iostream>
#include <queue>
#include <set>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include <boost/unordered_map.hpp>

#include "moses/InputFileStream.h"
#include "moses/ThreadPool.h"
#include "moses/Util.h"

#include "BlockHashIndex.h"
#include "StringVector.h"
#include "StringVectorTemp.h"
#include "CanonicalHuffman.h"
00040
00041 namespace Moses
00042 {
00043
// A single word-alignment point: a pair of word positions, presumably
// (source position, target position) -- NOTE(review): confirm against the
// alignment parser. Each position is stored as unsigned char, so it must be
// at most UCHAR_MAX (255 on common platforms).
typedef std::pair<unsigned char, unsigned char> AlignPoint;
00045
00046 template <typename DataType>
00047 class Counter
00048 {
00049 public:
00050 typedef boost::unordered_map<DataType, size_t> FreqMap;
00051 typedef typename FreqMap::iterator iterator;
00052 typedef typename FreqMap::mapped_type mapped_type;
00053 typedef typename FreqMap::value_type value_type;
00054
00055 private:
00056 #ifdef WITH_THREADS
00057 boost::mutex m_mutex;
00058 #endif
00059 FreqMap m_freqMap;
00060 size_t m_maxSize;
00061 std::vector<DataType> m_bestVec;
00062
00063 struct FreqSorter {
00064 bool operator()(const value_type& a, const value_type& b) const {
00065 if(a.second > b.second)
00066 return true;
00067
00068 if(a.second == b.second && a.first > b.first)
00069 return true;
00070 return false;
00071 }
00072 };
00073
00074 public:
00075 Counter() : m_maxSize(0) {}
00076
00077 iterator Begin() {
00078 return m_freqMap.begin();
00079 }
00080
00081 iterator End() {
00082 return m_freqMap.end();
00083 }
00084
00085 void Increase(DataType data) {
00086 #ifdef WITH_THREADS
00087 boost::mutex::scoped_lock lock(m_mutex);
00088 #endif
00089 m_freqMap[data]++;
00090 }
00091
00092 void IncreaseBy(DataType data, size_t num) {
00093 #ifdef WITH_THREADS
00094 boost::mutex::scoped_lock lock(m_mutex);
00095 #endif
00096 m_freqMap[data] += num;
00097 }
00098
00099 mapped_type& operator[](DataType data) {
00100 return m_freqMap[data];
00101 }
00102
00103 size_t Size() {
00104 #ifdef WITH_THREADS
00105 boost::mutex::scoped_lock lock(m_mutex);
00106 #endif
00107 return m_freqMap.size();
00108 }
00109
00110 void Quantize(size_t maxSize) {
00111 #ifdef WITH_THREADS
00112 boost::mutex::scoped_lock lock(m_mutex);
00113 #endif
00114 m_maxSize = maxSize;
00115 std::vector<std::pair<DataType, mapped_type> > freqVec;
00116 freqVec.insert(freqVec.begin(), m_freqMap.begin(), m_freqMap.end());
00117 std::sort(freqVec.begin(), freqVec.end(), FreqSorter());
00118
00119 for(size_t i = 0; i < freqVec.size() && i < m_maxSize; i++)
00120 m_bestVec.push_back(freqVec[i].first);
00121
00122 std::sort(m_bestVec.begin(), m_bestVec.end());
00123
00124 FreqMap t_freqMap;
00125 for(typename std::vector<std::pair<DataType, mapped_type> >::iterator it
00126 = freqVec.begin(); it != freqVec.end(); it++) {
00127 DataType closest = LowerBound(it->first);
00128 t_freqMap[closest] += it->second;
00129 }
00130
00131 m_freqMap.swap(t_freqMap);
00132 }
00133
00134 void Clear() {
00135 #ifdef WITH_THREADS
00136 boost::mutex::scoped_lock lock(m_mutex);
00137 #endif
00138 m_freqMap.clear();
00139 }
00140
00141 DataType LowerBound(DataType data) {
00142 if(m_maxSize == 0 || m_bestVec.size() == 0)
00143 return data;
00144 else {
00145 typename std::vector<DataType>::iterator it
00146 = std::lower_bound(m_bestVec.begin(), m_bestVec.end(), data);
00147 if(it != m_bestVec.end())
00148 return *it;
00149 else
00150 return m_bestVec.back();
00151 }
00152 }
00153 };
00154
/**
 * One phrase-table entry as it is handed between the pipeline stages
 * (PhraseTableCreator::AddRankedLine / AddEncodedLine /
 * AddCompressedCollection all take a PackedItem). It holds:
 *   - a line number,
 *   - the source phrase,
 *   - the packed target-phrase payload,
 *   - a rank,
 *   - an optional score.
 *
 * PhraseTableCreator keeps these in a std::priority_queue<PackedItem>, ordered
 * by the operator< declared below. All members are defined out of line.
 */
class PackedItem
{
private:
  long m_line;
  std::string m_sourcePhrase;
  std::string m_packedTargetPhrase;
  size_t m_rank;
  float m_score;

public:
  /// @param line               line number (presumably of the input line -- confirm)
  /// @param sourcePhrase       source-side phrase text
  /// @param packedTargetPhrase packed/encoded target-side payload
  /// @param rank               rank of this entry
  /// @param score              optional score, 0 by default
  PackedItem(long line, std::string sourcePhrase,
             std::string packedTargetPhrase, size_t rank,
             float score = 0);

  long GetLine() const;
  const std::string& GetSrc() const;
  const std::string& GetTrg() const;
  size_t GetRank() const;
  float GetScore() const;
};

/// Ordering for std::priority_queue<PackedItem>, which pops the greatest
/// element first. Defined out of line -- NOTE(review): presumably compares
/// line numbers so items can be flushed in input order; confirm in the .cpp.
bool operator<(const PackedItem &pi1, const PackedItem &pi2);
00177
00178 class PhraseTableCreator
00179 {
00180 public:
00181 enum Coding { None, REnc, PREnc };
00182
00183 private:
00184 std::string m_inPath;
00185 std::string m_outPath;
00186 std::string m_tempfilePath;
00187
00188 std::FILE* m_outFile;
00189
00190 size_t m_numScoreComponent;
00191 size_t m_sortScoreIndex;
00192 size_t m_warnMe;
00193
00194 Coding m_coding;
00195 size_t m_orderBits;
00196 size_t m_fingerPrintBits;
00197 bool m_useAlignmentInfo;
00198 bool m_multipleScoreTrees;
00199 size_t m_quantize;
00200 size_t m_maxRank;
00201
00202 static std::string m_phraseStopSymbol;
00203 static std::string m_separator;
00204
00205 #ifdef WITH_THREADS
00206 size_t m_threads;
00207 boost::mutex m_mutex;
00208 #endif
00209
00210 BlockHashIndex m_srcHash;
00211 BlockHashIndex m_rnkHash;
00212
00213 size_t m_maxPhraseLength;
00214
00215 std::vector<unsigned> m_ranks;
00216
00217 typedef std::pair<unsigned, unsigned> SrcTrg;
00218 typedef std::pair<std::string, std::string> SrcTrgString;
00219 typedef std::pair<SrcTrgString, float> SrcTrgProb;
00220
00221 struct SrcTrgProbSorter {
00222 bool operator()(const SrcTrgProb& a, const SrcTrgProb& b) const {
00223 if(a.first.first < b.first.first)
00224 return true;
00225
00226 if(a.first.first == b.first.first && a.second > b.second)
00227 return true;
00228
00229 if(a.first.first == b.first.first
00230 && a.second == b.second
00231 && a.first.second < b.first.second)
00232 return true;
00233
00234 return false;
00235 }
00236 };
00237
00238 std::vector<size_t> m_lexicalTableIndex;
00239 std::vector<SrcTrg> m_lexicalTable;
00240
00241 StringVectorTemp<unsigned char, unsigned long, MmapAllocator>*
00242 m_encodedTargetPhrases;
00243
00244 StringVector<unsigned char, unsigned long, MmapAllocator>*
00245 m_compressedTargetPhrases;
00246
00247 boost::unordered_map<std::string, unsigned> m_targetSymbolsMap;
00248 boost::unordered_map<std::string, unsigned> m_sourceSymbolsMap;
00249
00250 typedef Counter<unsigned> SymbolCounter;
00251 typedef Counter<float> ScoreCounter;
00252 typedef Counter<AlignPoint> AlignCounter;
00253
00254 typedef CanonicalHuffman<unsigned> SymbolTree;
00255 typedef CanonicalHuffman<float> ScoreTree;
00256 typedef CanonicalHuffman<AlignPoint> AlignTree;
00257
00258 SymbolCounter m_symbolCounter;
00259 SymbolTree* m_symbolTree;
00260
00261 AlignCounter m_alignCounter;
00262 AlignTree* m_alignTree;
00263
00264 std::vector<ScoreCounter*> m_scoreCounters;
00265 std::vector<ScoreTree*> m_scoreTrees;
00266
00267 std::priority_queue<PackedItem> m_queue;
00268 long m_lastFlushedLine;
00269 long m_lastFlushedSourceNum;
00270 std::string m_lastFlushedSourcePhrase;
00271 std::vector<std::string> m_lastSourceRange;
00272 std::priority_queue<std::pair<float, size_t> > m_rankQueue;
00273 std::vector<std::string> m_lastCollection;
00274
00275 void Save();
00276 void PrintInfo();
00277
00278 void AddSourceSymbolId(std::string& symbol);
00279 unsigned GetSourceSymbolId(std::string& symbol);
00280
00281 void AddTargetSymbolId(std::string& symbol);
00282 unsigned GetTargetSymbolId(std::string& symbol);
00283 unsigned GetOrAddTargetSymbolId(std::string& symbol);
00284
00285 unsigned GetRank(unsigned srcIdx, unsigned trgIdx);
00286
00287 unsigned EncodeREncSymbol1(unsigned symbol);
00288 unsigned EncodeREncSymbol2(unsigned position, unsigned rank);
00289 unsigned EncodeREncSymbol3(unsigned rank);
00290
00291 unsigned EncodePREncSymbol1(unsigned symbol);
00292 unsigned EncodePREncSymbol2(int lOff, int rOff, unsigned rank);
00293
00294 void EncodeTargetPhraseNone(std::vector<std::string>& t,
00295 std::ostream& os);
00296
00297 void EncodeTargetPhraseREnc(std::vector<std::string>& s,
00298 std::vector<std::string>& t,
00299 std::set<AlignPoint>& a,
00300 std::ostream& os);
00301
00302 void EncodeTargetPhrasePREnc(std::vector<std::string>& s,
00303 std::vector<std::string>& t,
00304 std::set<AlignPoint>& a, size_t ownRank,
00305 std::ostream& os);
00306
00307 void EncodeScores(std::vector<float>& scores, std::ostream& os);
00308 void EncodeAlignment(std::set<AlignPoint>& alignment, std::ostream& os);
00309
00310 std::string MakeSourceKey(std::string&);
00311 std::string MakeSourceTargetKey(std::string&, std::string&);
00312
00313 void LoadLexicalTable(std::string filePath);
00314
00315 void CreateRankHash();
00316 void EncodeTargetPhrases();
00317 void CalcHuffmanCodes();
00318 void CompressTargetPhrases();
00319
00320 void AddRankedLine(PackedItem& pi);
00321 void FlushRankedQueue(bool force = false);
00322
00323 std::string EncodeLine(std::vector<std::string>& tokens, size_t ownRank);
00324 void AddEncodedLine(PackedItem& pi);
00325 void FlushEncodedQueue(bool force = false);
00326
00327 std::string CompressEncodedCollection(std::string encodedCollection);
00328 void AddCompressedCollection(PackedItem& pi);
00329 void FlushCompressedQueue(bool force = false);
00330
00331 public:
00332
00333 PhraseTableCreator(std::string inPath,
00334 std::string outPath,
00335 std::string tempfilePath,
00336 size_t numScoreComponent = 5,
00337 size_t sortScoreIndex = 2,
00338 Coding coding = PREnc,
00339 size_t orderBits = 10,
00340 size_t fingerPrintBits = 16,
00341 bool useAlignmentInfo = false,
00342 bool multipleScoreTrees = true,
00343 size_t quantize = 0,
00344 size_t maxRank = 100,
00345 bool warnMe = true
00346 #ifdef WITH_THREADS
00347 , size_t threads = 2
00348 #endif
00349 );
00350
00351 ~PhraseTableCreator();
00352
00353 friend class RankingTask;
00354 friend class EncodingTask;
00355 friend class CompressionTask;
00356 };
00357
00358 class RankingTask
00359 {
00360 private:
00361 #ifdef WITH_THREADS
00362 static boost::mutex m_mutex;
00363 static boost::mutex m_fileMutex;
00364 #endif
00365 static size_t m_lineNum;
00366 InputFileStream& m_inFile;
00367 PhraseTableCreator& m_creator;
00368
00369 public:
00370 RankingTask(InputFileStream& inFile, PhraseTableCreator& creator);
00371 void operator()();
00372 };
00373
00374 class EncodingTask
00375 {
00376 private:
00377 #ifdef WITH_THREADS
00378 static boost::mutex m_mutex;
00379 static boost::mutex m_fileMutex;
00380 #endif
00381 static size_t m_lineNum;
00382 static size_t m_sourcePhraseNum;
00383 static std::string m_lastSourcePhrase;
00384
00385 InputFileStream& m_inFile;
00386 PhraseTableCreator& m_creator;
00387
00388 public:
00389 EncodingTask(InputFileStream& inFile, PhraseTableCreator& creator);
00390 void operator()();
00391 };
00392
00393 class CompressionTask
00394 {
00395 private:
00396 #ifdef WITH_THREADS
00397 static boost::mutex m_mutex;
00398 #endif
00399 static size_t m_collectionNum;
00400 StringVectorTemp<unsigned char, unsigned long, MmapAllocator>&
00401 m_encodedCollections;
00402 PhraseTableCreator& m_creator;
00403
00404 public:
00405 CompressionTask(StringVectorTemp<unsigned char, unsigned long, MmapAllocator>&
00406 encodedCollections, PhraseTableCreator& creator);
00407 void operator()();
00408 };
00409
00410 }
00411
00412 #endif