00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022 #include <limits>
00023 #include <iostream>
00024 #include <fstream>
00025 #include "dictionary.h"
00026 #include "n_gram.h"
00027 #include "lmContainer.h"
00028
00029 using namespace irstlm;
00030
00031 #include "IRST.h"
00032 #include "moses/LM/PointerState.h"
00033 #include "moses/TypeDef.h"
00034 #include "moses/Util.h"
00035 #include "moses/FactorCollection.h"
00036 #include "moses/Phrase.h"
00037 #include "moses/InputFileStream.h"
00038 #include "moses/StaticData.h"
00039 #include "moses/TranslationTask.h"
00040
00041 using namespace std;
00042
00043 namespace Moses
00044 {
00045
// Hypothesis-recombination state for the IRST LM: a thin wrapper around
// PointerState whose payload is IRSTLM's opaque "most specific prefix"
// pointer (the char* returned via clprob/cmaxsuffptr). Two hypotheses with
// the same pointer are LM-equivalent. The pointer is never owned here.
class IRSTLMState : public PointerState
{
public:
  // Empty state: nothing scored yet (NULL LM-state pointer).
  IRSTLMState():PointerState(NULL) {}
  // Wrap an existing IRSTLM state pointer.
  IRSTLMState(const void* lms):PointerState(lms) {}
  IRSTLMState(const IRSTLMState& copy_from):PointerState(copy_from.lmstate) {}

  // Copy the raw pointer only; no ownership transfer.
  IRSTLMState& operator=( const IRSTLMState& rhs ) {
    lmstate = rhs.lmstate;
    return *this;
  }

  // Expose the opaque state pointer for comparison.
  const void* GetState() const {
    return lmstate;
  }
};
00062
// Construct from a moses.ini feature-function line. IRSTLM's internal
// caches are not threadsafe, so a configuration with more than one
// decoding thread is rejected up front.
LanguageModelIRST::LanguageModelIRST(const std::string &line)
  :LanguageModelSingleFactor(line)
  ,m_lmtb_dub(0), m_lmtb_size(0)
{
  const StaticData &staticData = StaticData::Instance();
  int threadCount = staticData.ThreadCount();
  if (threadCount != 1) {
    throw runtime_error("Error: " + SPrint(threadCount) + " number of threads specified but IRST LM is not threadsafe.");
  }

  // Parse key=value options from the feature line (see SetParameter;
  // may set m_lmtb_dub, m_filePath, m_factorType, ...).
  ReadParameters();

  VERBOSE(4, GetScoreProducerDescription() << " LanguageModelIRST::LanguageModelIRST() m_lmtb_dub:|" << m_lmtb_dub << "|" << std::endl);
  VERBOSE(4, GetScoreProducerDescription() << " LanguageModelIRST::LanguageModelIRST() m_filePath:|" << m_filePath << "|" << std::endl);
  VERBOSE(4, GetScoreProducerDescription() << " LanguageModelIRST::LanguageModelIRST() m_factorType:|" << m_factorType << "|" << std::endl);
  VERBOSE(4, GetScoreProducerDescription() << " LanguageModelIRST::LanguageModelIRST() m_lmtb_size:|" << m_lmtb_size << "|" << std::endl);
}
00080
// Release the IRSTLM container. On non-Windows builds, first drop any
// memory-mapped model data (reset_mmap) before deleting the table.
LanguageModelIRST::~LanguageModelIRST()
{

#ifndef WIN32
  TRACE_ERR( "reset mmap\n");
  if (m_lmtb) m_lmtb->reset_mmap();
#endif

  delete m_lmtb;
}
00091
00092
00093 bool LanguageModelIRST::IsUseable(const FactorMask &mask) const
00094 {
00095 bool ret = mask[m_factorType];
00096 return ret;
00097 }
00098
00099 void LanguageModelIRST::Load(AllOptions::ptr const& opts)
00100 {
00101 FactorCollection &factorCollection = FactorCollection::Instance();
00102
00103 m_lmtb = m_lmtb->CreateLanguageModel(m_filePath);
00104 if (m_lmtb_size > 0) m_lmtb->setMaxLoadedLevel(m_lmtb_size);
00105 m_lmtb->load(m_filePath);
00106 d=m_lmtb->getDict();
00107 d->incflag(1);
00108
00109 m_nGramOrder = m_lmtb_size = m_lmtb->maxlevel();
00110
00111
00112
00113 m_unknownId = d->oovcode();
00114 m_empty = -1;
00115
00116 CreateFactors(factorCollection);
00117
00118 VERBOSE(1, GetScoreProducerDescription() << " LanguageModelIRST::Load() m_unknownId=" << m_unknownId << std::endl);
00119
00120
00121 m_lmtb->init_caches(m_lmtb_size>2?m_lmtb_size-1:2);
00122
00123 if (m_lmtb_dub > 0) m_lmtb->setlogOOVpenalty(m_lmtb_dub);
00124 }
00125
00126 void LanguageModelIRST::CreateFactors(FactorCollection &factorCollection)
00127 {
00128
00129
00130 std::map<size_t, int> lmIdMap;
00131 size_t maxFactorId = 0;
00132 m_empty = -1;
00133
00134 dict_entry *entry;
00135 dictionary_iter iter(d);
00136 while ( (entry = iter.next()) != NULL) {
00137 size_t factorId = factorCollection.AddFactor(Output, m_factorType, entry->word)->GetId();
00138 lmIdMap[factorId] = entry->code;
00139 maxFactorId = (factorId > maxFactorId) ? factorId : maxFactorId;
00140 }
00141
00142 size_t factorId;
00143
00144 m_sentenceStart = factorCollection.AddFactor(Output, m_factorType, BOS_);
00145 factorId = m_sentenceStart->GetId();
00146 const std::string bs = BOS_;
00147 const std::string es = EOS_;
00148 m_lmtb_sentenceStart=lmIdMap[factorId] = GetLmID(BOS_);
00149 maxFactorId = (factorId > maxFactorId) ? factorId : maxFactorId;
00150 m_sentenceStartWord[m_factorType] = m_sentenceStart;
00151
00152 m_sentenceEnd = factorCollection.AddFactor(Output, m_factorType, EOS_);
00153 factorId = m_sentenceEnd->GetId();
00154 m_lmtb_sentenceEnd=lmIdMap[factorId] = GetLmID(EOS_);
00155 maxFactorId = (factorId > maxFactorId) ? factorId : maxFactorId;
00156 m_sentenceEndWord[m_factorType] = m_sentenceEnd;
00157
00158
00159 m_lmIdLookup.resize(maxFactorId+1);
00160 fill(m_lmIdLookup.begin(), m_lmIdLookup.end(), m_empty);
00161
00162 map<size_t, int>::iterator iterMap;
00163 for (iterMap = lmIdMap.begin() ; iterMap != lmIdMap.end() ; ++iterMap) {
00164 m_lmIdLookup[iterMap->first] = iterMap->second;
00165 }
00166 }
00167
// Map a surface string to its IRSTLM word code.
// NOTE(review): dictionary::encode may ADD the word when the dictionary's
// incflag is set (as Load sets it), so this call can grow the dictionary.
int LanguageModelIRST::GetLmID( const std::string &str ) const
{
  return d->encode( str.c_str() );
}
00172
// Map a Moses word to its IRSTLM code via the factor this LM scores.
int LanguageModelIRST::GetLmID( const Word &word ) const
{
  return GetLmID( word.GetFactor(m_factorType) );
}
00177
// Map a factor to its IRSTLM word code, using m_lmIdLookup as a cache.
// On a miss: if the dictionary is still growable (incflag()==1) the word is
// encoded — and thereby added — on the fly and cached; otherwise the OOV
// code is returned.
// NOTE(review): this const method mutates the lookup table — presumably
// m_lmIdLookup is declared mutable in the header; confirm. Not threadsafe,
// consistent with the single-thread check in the constructor.
int LanguageModelIRST::GetLmID( const Factor *factor ) const
{
  size_t factorId = factor->GetId();

  // Cache miss: id beyond the table, or the slot holds the m_empty sentinel.
  if ((factorId >= m_lmIdLookup.size()) || (m_lmIdLookup[factorId] == m_empty)) {
    if (d->incflag()==1) {
      std::string s = factor->GetString().as_string();
      int code = d->encode(s.c_str());

      // Grow with a little slack (+10) so a burst of new factors does not
      // trigger a resize per word; new slots start at the m_empty sentinel.
      if (factorId >= m_lmIdLookup.size()) {
        m_lmIdLookup.resize(factorId+10, m_empty);
      }

      m_lmIdLookup[factorId] = code;
      return code;

    } else {
      return m_unknownId;
    }
  } else {
    return m_lmIdLookup[factorId];
  }
}
00233
00234 const FFState* LanguageModelIRST::EmptyHypothesisState(const InputType &) const
00235 {
00236 std::auto_ptr<IRSTLMState> ret(new IRSTLMState());
00237
00238 return ret.release();
00239 }
00240
// Score a target phrase in isolation (used for precomputed/future-cost
// estimates). On return: fullScore = total log score of all words,
// ngramScore = portion scored with a complete n-gram context,
// oovCount = number of out-of-vocabulary words encountered.
void LanguageModelIRST::CalcScore(const Phrase &phrase, float &fullScore, float &ngramScore, size_t &oovCount) const
{
  fullScore = 0;
  ngramScore = 0;
  oovCount = 0;

  if ( !phrase.GetSize() ) return;

  // Number of leading words scored with an incomplete, <s>-anchored context.
  int _min = min(m_lmtb_size - 1, (int) phrase.GetSize());

  // NOTE(review): variable-length array — a gcc/clang extension, not
  // standard C++.
  int codes[m_lmtb_size];
  int idx = 0;
  codes[idx] = m_lmtb_sentenceStart;
  ++idx;
  int position = 0;

  char* msp = NULL;
  // Phase 1: grow the context one word at a time (partial n-grams).
  float before_boundary = 0.0;
  for (; position < _min; ++position) {
    codes[idx] = GetLmID(phrase.GetWord(position));
    if (codes[idx] == m_unknownId) ++oovCount;
    before_boundary += m_lmtb->clprob(codes,idx+1,NULL,NULL,&msp);
    ++idx;
  }

  // Phase 2: full n-gram window; shift left by one and append each word.
  ngramScore = 0.0;
  int end_loop = (int) phrase.GetSize();

  for (; position < end_loop; ++position) {
    for (idx = 1; idx < m_lmtb_size; ++idx) {
      codes[idx-1] = codes[idx];
    }
    codes[idx-1] = GetLmID(phrase.GetWord(position));
    if (codes[idx-1] == m_unknownId) ++oovCount;
    ngramScore += m_lmtb->clprob(codes,idx,NULL,NULL,&msp);
  }
  // Convert IRSTLM's scores into Moses' internal score domain.
  before_boundary = TransformLMScore(before_boundary);
  ngramScore = TransformLMScore(ngramScore);
  fullScore = ngramScore + before_boundary;
}
00281
00282 FFState* LanguageModelIRST::EvaluateWhenApplied(const Hypothesis &hypo, const FFState *ps, ScoreComponentCollection *out) const
00283 {
00284 if (!hypo.GetCurrTargetLength()) {
00285 std::auto_ptr<IRSTLMState> ret(new IRSTLMState(ps));
00286 return ret.release();
00287 }
00288
00289
00290 const int begin = (const int) hypo.GetCurrTargetWordsRange().GetStartPos();
00291 const int end = (const int) hypo.GetCurrTargetWordsRange().GetEndPos() + 1;
00292 const int adjust_end = (const int) std::min(end, begin + m_lmtb_size - 1);
00293
00294
00295
00296
00297 int codes[m_lmtb_size];
00298 int idx=m_lmtb_size-1;
00299 int position = (const int) begin;
00300 while (position >= 0) {
00301 codes[idx] = GetLmID(hypo.GetWord(position));
00302 --idx;
00303 --position;
00304 }
00305 while (idx>=0) {
00306 codes[idx] = m_lmtb_sentenceStart;
00307 --idx;
00308 }
00309
00310 char* msp = NULL;
00311 float score = m_lmtb->clprob(codes,m_lmtb_size,NULL,NULL,&msp);
00312
00313 position = (const int) begin+1;
00314 while (position < adjust_end) {
00315 for (idx=1; idx<m_lmtb_size; idx++) {
00316 codes[idx-1] = codes[idx];
00317 }
00318 codes[idx-1] = GetLmID(hypo.GetWord(position));
00319 score += m_lmtb->clprob(codes,m_lmtb_size,NULL,NULL,&msp);
00320 ++position;
00321 }
00322
00323
00324
00325 if (hypo.IsSourceCompleted()) {
00326 idx=m_lmtb_size-1;
00327 codes[idx] = m_lmtb_sentenceEnd;
00328 --idx;
00329 position = (const int) end - 1;
00330 while (position >= 0 && idx >= 0) {
00331 codes[idx] = GetLmID(hypo.GetWord(position));
00332 --idx;
00333 --position;
00334 }
00335 while (idx>=0) {
00336 codes[idx] = m_lmtb_sentenceStart;
00337 --idx;
00338 }
00339 score += m_lmtb->clprob(codes,m_lmtb_size,NULL,NULL,&msp);
00340 } else {
00341
00342
00343 if (adjust_end < end) {
00344 position = (const int) end - 1;
00345 for (idx=m_lmtb_size-1; idx>0; --idx) {
00346 codes[idx] = GetLmID(hypo.GetWord(position));
00347 }
00348 codes[idx] = m_lmtb_sentenceStart;
00349 msp = (char *) m_lmtb->cmaxsuffptr(codes,m_lmtb_size);
00350 }
00351 }
00352
00353 score = TransformLMScore(score);
00354 out->PlusEquals(this, score);
00355
00356 std::auto_ptr<IRSTLMState> ret(new IRSTLMState(msp));
00357
00358 return ret.release();
00359 }
00360
00361 LMResult LanguageModelIRST::GetValue(const vector<const Word*> &contextFactor, State* finalState) const
00362 {
00363
00364 size_t count = contextFactor.size();
00365 if (count < 0) {
00366 cerr << "ERROR count < 0\n";
00367 exit(100);
00368 };
00369
00370
00371 int codes[MAX_NGRAM_SIZE];
00372
00373 size_t idx=0;
00374
00375
00376 if (count < (size_t) (m_lmtb_size-1)) codes[idx++] = m_lmtb_sentenceEnd;
00377 if (count < (size_t) m_lmtb_size) codes[idx++] = m_lmtb_sentenceStart;
00378
00379 for (size_t i = 0 ; i < count ; i++) {
00380 codes[idx] = GetLmID(*contextFactor[i]);
00381 ++idx;
00382 }
00383
00384 LMResult result;
00385 result.unknown = (codes[idx - 1] == m_unknownId);
00386
00387 char* msp = NULL;
00388 result.score = m_lmtb->clprob(codes,idx,NULL,NULL,&msp);
00389
00390 if (finalState) *finalState=(State *) msp;
00391
00392 result.score = TransformLMScore(result.score);
00393
00394 return result;
00395 }
00396
// Decide whether the LM caches should be flushed after this sentence.
// A sentinel of -1 forces a cleanup; otherwise clean up every
// m_lmcache_cleanup_threshold sentences (a threshold of 0 disables it).
bool LMCacheCleanup(const int sentences_done, const size_t m_lmcache_cleanup_threshold)
{
  if (sentences_done == -1) return true;
  return m_lmcache_cleanup_threshold != 0
         && sentences_done % m_lmcache_cleanup_threshold == 0;
}
00405
// Per-input-sentence setup. Only bumps IRSTLM's internal sentence id when
// cache tracing is compiled in (TRACE_CACHE); otherwise a no-op.
void LanguageModelIRST::InitializeForInput(ttasksptr const& ttask)
{
#ifdef TRACE_CACHE
  m_lmtb->sentence_id++;
#endif
}
00413
// After each sentence, periodically flush IRSTLM's internal caches to bound
// memory use; the flush period comes from the global configuration.
void LanguageModelIRST::CleanUpAfterSentenceProcessing(const InputType& source)
{
  const StaticData &staticData = StaticData::Instance();
  // NOTE(review): function-local static counter — not threadsafe, but the
  // constructor already rejects runs with more than one thread.
  static int sentenceCount = 0;
  sentenceCount++;

  size_t lmcache_cleanup_threshold = staticData.GetLMCacheCleanupThreshold();

  if (LMCacheCleanup(sentenceCount, lmcache_cleanup_threshold)) {
    TRACE_ERR( "reset caches\n");
    m_lmtb->reset_caches();
  }
}
00427
00428 void LanguageModelIRST::SetParameter(const std::string& key, const std::string& value)
00429 {
00430 if (key == "dub") {
00431 m_lmtb_dub = Scan<unsigned int>(value);
00432 } else {
00433 LanguageModelSingleFactor::SetParameter(key, value);
00434 }
00435 m_lmtb_size = m_nGramOrder;
00436 }
00437
00438 }
00439