00001
00002
00003 #include <algorithm>
00004 #include "moses/FF/FFState.h"
00005 #include "DALMWrapper.h"
00006 #include "dalm.h"
00007 #include "moses/FactorTypeSet.h"
00008 #include "moses/FactorCollection.h"
00009 #include "moses/InputFileStream.h"
00010 #include "util/exception.hh"
00011 #include "moses/ChartHypothesis.h"
00012 #include "moses/ChartManager.h"
00013
00014 using namespace std;
00015
// Parse a DALM ini file consisting of KEY=VALUE lines and fill in the
// (relative) paths of the model binary (MODEL), the vocabulary binary
// (WORDS) and the plain-text vocabulary (WORDSTXT).
// Unknown keys and lines without '=' are ignored.
void read_ini(const char *inifile, std::string &model, std::string &words, std::string &wordstxt)
{
  std::ifstream ifs(inifile);
  std::string line;

  while (std::getline(ifs, line)) {
    // find() returns npos (a size_t) when '=' is missing; the previous code
    // stored it in an unsigned int, truncating npos and making the substr
    // calls behave differently per platform (possibly throwing out_of_range).
    std::string::size_type pos = line.find('=');
    if (pos == std::string::npos) continue; // skip malformed lines
    std::string key = line.substr(0, pos);
    std::string value = line.substr(pos + 1);
    if (key == "MODEL") {
      model = value;
    } else if (key == "WORDS") {
      words = value;
    } else if (key == "WORDSTXT") {
      wordstxt = value;
    }
  }
}
00038
00039 namespace Moses
00040 {
00041
00042 class Murmur: public DALM::State::HashFunction
00043 {
00044 public:
00045 Murmur(std::size_t seed=0): seed(seed) {
00046 }
00047 virtual std::size_t operator()(const DALM::VocabId *words, std::size_t size) const {
00048 return util::MurmurHashNative(words, sizeof(DALM::VocabId) * size, seed);
00049 }
00050 private:
00051 std::size_t seed;
00052 };
00053
00054 class DALMState : public FFState
00055 {
00056 private:
00057 DALM::State state;
00058
00059 public:
00060 DALMState() {
00061 }
00062
00063 DALMState(const DALMState &from) {
00064 state = from.state;
00065 }
00066
00067 virtual ~DALMState() {
00068 }
00069
00070 void reset(const DALMState &from) {
00071 state = from.state;
00072 }
00073
00074 virtual int Compare(const FFState& other) const {
00075 const DALMState &o = static_cast<const DALMState &>(other);
00076 if(state.get_count() < o.state.get_count()) return -1;
00077 else if(state.get_count() > o.state.get_count()) return 1;
00078 else return state.compare(o.state);
00079 }
00080
00081 virtual size_t hash() const {
00082
00083 return state.hash(Murmur());
00084 }
00085
00086 virtual bool operator==(const FFState& other) const {
00087 const DALMState &o = static_cast<const DALMState &>(other);
00088 return state.compare(o.state) == 0;
00089 }
00090
00091 DALM::State &get_state() {
00092 return state;
00093 }
00094
00095 void refresh() {
00096 state.refresh();
00097 }
00098 };
00099
00100 class DALMChartState : public FFState
00101 {
00102 private:
00103 DALM::Fragment prefixFragments[DALM_MAX_ORDER-1];
00104 unsigned char prefixLength;
00105 DALM::State rightContext;
00106 bool isLarge;
00107
00108 public:
00109 DALMChartState()
00110 : prefixLength(0),
00111 isLarge(false) {
00112 }
00113
00114 virtual ~DALMChartState() {
00115 }
00116
00117 inline unsigned char GetPrefixLength() const {
00118 return prefixLength;
00119 }
00120
00121 inline unsigned char &GetPrefixLength() {
00122 return prefixLength;
00123 }
00124
00125 inline const DALM::Fragment *GetPrefixFragments() const {
00126 return prefixFragments;
00127 }
00128
00129 inline DALM::Fragment *GetPrefixFragments() {
00130 return prefixFragments;
00131 }
00132
00133 inline const DALM::State &GetRightContext() const {
00134 return rightContext;
00135 }
00136
00137 inline DALM::State &GetRightContext() {
00138 return rightContext;
00139 }
00140
00141 inline bool LargeEnough() const {
00142 return isLarge;
00143 }
00144
00145 inline void SetAsLarge() {
00146 isLarge=true;
00147 }
00148
00149 virtual int Compare(const FFState& other) const {
00150 const DALMChartState &o = static_cast<const DALMChartState &>(other);
00151 if(prefixLength < o.prefixLength) return -1;
00152 if(prefixLength > o.prefixLength) return 1;
00153 if(prefixLength!=0) {
00154 const DALM::Fragment &f = prefixFragments[prefixLength-1];
00155 const DALM::Fragment &of = o.prefixFragments[prefixLength-1];
00156 int ret = DALM::compare_fragments(f,of);
00157 if(ret != 0) return ret;
00158 }
00159 if(isLarge != o.isLarge) return (int)isLarge - (int)o.isLarge;
00160 if(rightContext.get_count() < o.rightContext.get_count()) return -1;
00161 if(rightContext.get_count() > o.rightContext.get_count()) return 1;
00162 return rightContext.compare(o.rightContext);
00163 }
00164
00165 virtual size_t hash() const {
00166
00167 unsigned char add[2];
00168 add[0] = prefixLength;
00169 add[1] = isLarge;
00170 std::size_t seed = util::MurmurHashNative(add, 2, prefixLength ? prefixFragments[prefixLength-1].sid : 0);
00171 return rightContext.hash(Murmur(seed));
00172 }
00173
00174 virtual bool operator==(const FFState& other) const {
00175 const DALMChartState &o = static_cast<const DALMChartState &>(other);
00176
00177
00178 if(prefixLength != o.prefixLength) return false;
00179 const DALM::Fragment &f = prefixFragments[prefixLength-1];
00180 const DALM::Fragment &of = o.prefixFragments[prefixLength-1];
00181 if(DALM::compare_fragments(f, of) != 0) return false;
00182
00183
00184 if(rightContext.get_count() != o.rightContext.get_count()) return false;
00185 return rightContext.compare(o.rightContext) == 0;
00186 }
00187
00188 };
00189
00190 LanguageModelDALM::LanguageModelDALM(const std::string &line)
00191 :LanguageModel(line)
00192 {
00193 ReadParameters();
00194
00195 if (m_factorType == NOT_FOUND) {
00196 m_factorType = 0;
00197 }
00198 }
00199
// Release the objects allocated in Load().
// NOTE(review): m_lm was constructed with references to *m_vocab and
// *m_logger, which are deleted before it here — this assumes DALM::LM's
// destructor does not use them; confirm against the DALM library.
LanguageModelDALM::~LanguageModelDALM()
{
  delete m_logger;
  delete m_vocab;
  delete m_lm;
}
00206
00207 void LanguageModelDALM::Load(AllOptions::ptr const& opts)
00208 {
00210
00212 string inifile= m_filePath + "/dalm.ini";
00213
00214 string model;
00215 string words;
00216 string wordstxt;
00217 read_ini(inifile.c_str(), model, words, wordstxt);
00218
00219 model = m_filePath + "/" + model;
00220 words = m_filePath + "/" + words;
00221 wordstxt = m_filePath + "/" + wordstxt;
00222
00223 UTIL_THROW_IF(model.empty() || words.empty() || wordstxt.empty(),
00224 util::FileOpenException,
00225 "Failed to read DALM ini file " << m_filePath << ". Probably doesn't exist");
00226
00228
00230
00231
00232 m_logger = new DALM::Logger(stderr);
00233 m_logger->setLevel(DALM::LOGGER_INFO);
00234
00235
00236 m_vocab = new DALM::Vocabulary(words, *m_logger);
00237
00238
00239 m_lm = new DALM::LM(model, *m_vocab, m_nGramOrder, *m_logger);
00240
00241 wid_start = m_vocab->lookup(BOS_);
00242 wid_end = m_vocab->lookup(EOS_);
00243
00244
00245 CreateVocabMapping(wordstxt);
00246
00247 FactorCollection &collection = FactorCollection::Instance();
00248 m_beginSentenceFactor = collection.AddFactor(BOS_);
00249 }
00250
00251 const FFState *LanguageModelDALM::EmptyHypothesisState(const InputType &) const
00252 {
00253 DALMState *s = new DALMState();
00254 m_lm->init_state(s->get_state());
00255 return s;
00256 }
00257
// Score an isolated target phrase with the DALM n-gram model.
//   fullScore  - transformed LM score of the whole phrase.
//   ngramScore - fullScore minus the "prefix" part (the first m_ContextSize
//                words), i.e. the portion scored with complete context.
//   oovCount   - number of words that mapped to DALM's <unk> id.
void LanguageModelDALM::CalcScore(const Phrase &phrase, float &fullScore, float &ngramScore, size_t &oovCount) const
{
  oovCount = 0;
  fullScore = 0.0f;
  ngramScore = 0.0f;

  size_t phraseSize = phrase.GetSize();
  if (!phraseSize) return;

  size_t currPos = 0;

  DALM::State state;

  // <s> contributes context but is not scored itself.
  if(phrase.GetWord(0).GetFactor(m_factorType) == m_beginSentenceFactor) {
    m_lm->init_state(state);
    currPos++;

  }

  float score;
  float prefixScore=0.0f; // score of the first m_ContextSize words
  float partScore=0.0f;   // score accumulated since the last gap

  // Phase 1: score at most the first m_ContextSize words. Their sum forms
  // the prefix score, which may change when the phrase gains left context.
  while (currPos < phraseSize) {
    const Word &word = phrase.GetWord(currPos);

    if (word.IsNonTerminal()) {
      // Gap: the context is unknown beyond this point, so restart the LM
      // state and fold the finished run into fullScore.
      state.refresh();
      fullScore += partScore;
      partScore = 0.0f;


    } else {

      DALM::VocabId wid = GetVocabId(word.GetFactor(m_factorType));
      score = m_lm->query(wid, state);
      partScore += score;


      if (wid==m_vocab->unk()) ++oovCount;
    }

    currPos++;
    if (currPos >= m_ContextSize) {
      break;
    }
  }
  prefixScore = fullScore + partScore;

  // Phase 2: the remaining words, which always carry a full n-1 word context.
  while (currPos < phraseSize) {
    const Word &word = phrase.GetWord(currPos);

    if (word.IsNonTerminal()) {
      // Gap: flush and restart as above.
      fullScore += partScore;
      partScore = 0.0f;

      state.refresh();

    } else {

      DALM::VocabId wid = GetVocabId(word.GetFactor(m_factorType));
      score = m_lm->query(wid, state);
      partScore += score;

      if (wid==m_vocab->unk()) ++oovCount;
    }

    currPos++;
  }
  fullScore += partScore;

  ngramScore = TransformLMScore(fullScore - prefixScore);
  fullScore = TransformLMScore(fullScore);
}
00338
// Phrase-based LM scoring: extend the previous hypothesis' LM state with the
// words this hypothesis adds and accumulate their score into `out`.
// Returns the new LM state (owned by the caller/search).
FFState *LanguageModelDALM::EvaluateWhenApplied(const Hypothesis &hypo, const FFState *ps, ScoreComponentCollection *out) const
{
  const DALMState *dalm_ps = static_cast<const DALMState *>(ps);

  // Nothing added: pass the previous state through unchanged (copied).
  if (hypo.GetCurrTargetLength() == 0) {
    return dalm_ps ? new DALMState(*dalm_ps) : NULL;
  }

  const std::size_t begin = hypo.GetCurrTargetWordsRange().GetStartPos();
  // one past the last target position covered by this hypothesis
  const std::size_t end = hypo.GetCurrTargetWordsRange().GetEndPos() + 1;
  // Only the first order-1 new words need incremental queries; after that
  // the state can be rebuilt directly from the last n-1 words (below).
  const std::size_t adjust_end = std::min(end, begin + m_nGramOrder - 1);

  DALMState *dalm_state = new DALMState(*dalm_ps);
  DALM::State &state = dalm_state->get_state();
  float score = 0.0;
  for(std::size_t position=begin; position < adjust_end; position++) {
    score += m_lm->query(GetVocabId(hypo.GetWord(position).GetFactor(m_factorType)), state);
  }

  if (hypo.IsSourceCompleted()) {
    // Sentence finished: rebuild the state from the final n-1 words and
    // score the end-of-sentence marker.
    std::vector<DALM::VocabId> indices(m_nGramOrder-1);
    const DALM::VocabId *last = LastIDs(hypo, &indices.front());
    m_lm->set_state(&indices.front(), (last-&indices.front()), state);

    score += m_lm->query(wid_end, state);
  } else if (adjust_end < end) {
    // Some words were skipped above; rebuild the state from the hypothesis'
    // last n-1 words so future extensions see the correct context.
    std::vector<DALM::VocabId> indices(m_nGramOrder-1);
    const DALM::VocabId *last = LastIDs(hypo, &indices.front());
    m_lm->set_state(&indices.front(), (last-&indices.front()), state);
  }

  score = TransformLMScore(score);
  if (OOVFeatureEnabled()) {
    // Second component (OOV count) is not tracked here and stays 0.
    std::vector<float> scores(2);
    scores[0] = score;
    scores[1] = 0.0;
    out->PlusEquals(this, scores);
  } else {
    out->PlusEquals(this, score);
  }

  return dalm_state;
}
00390
// Chart (hierarchical) LM scoring: walk the rule's target side left to
// right, scoring terminals directly and splicing in the LM states of the
// nonterminals' subderivations. Returns the combined DALMChartState.
FFState *LanguageModelDALM::EvaluateWhenApplied(const ChartHypothesis& hypo, int featureID, ScoreComponentCollection *out) const
{
  // New state being built: right context, growable prefix, isLarge flag.
  DALMChartState *newState = new DALMChartState();
  DALM::State &state = newState->GetRightContext();

  DALM::Fragment *prefixFragments = newState->GetPrefixFragments();
  unsigned char &prefixLength = newState->GetPrefixLength();

  // LM score accumulated for this hypothesis (log domain, untransformed).
  float hypoScore = 0.0;

  const TargetPhrase &targetPhrase = hypo.GetCurrTargetPhrase();
  size_t hypoSize = targetPhrase.GetSize();

  // Maps a nonterminal's target position to its index among prev hypotheses.
  const AlignmentInfo::NonTermIndexMap &nonTermIndexMap =
    targetPhrase.GetAlignNonTerm().GetNonTermIndexMap();

  size_t phrasePos = 0;

  // Special-case the first symbol of the rule.
  if(hypoSize > 0) {
    const Word &word = targetPhrase.GetWord(0);
    if(word.GetFactor(m_factorType) == m_beginSentenceFactor) {
      // <s>: the left context is fully determined, so the prefix is final.
      m_lm->init_state(state);

      newState->SetAsLarge();
      phrasePos++;
    } else if(word.IsNonTerminal()) {
      // Rule starts with a nonterminal: adopt the entire LM state of that
      // subderivation as our starting point.
      const ChartHypothesis *prevHypo = hypo.GetPrevHypo(nonTermIndexMap[0]);
      const DALMChartState* prevState =
        static_cast<const DALMChartState*>(prevHypo->GetFFState(featureID));

      (*newState) = (*prevState);

      phrasePos++;
    }
  }

  // Walk the rest of the rule's target side.
  for (; phrasePos < hypoSize; phrasePos++) {

    const Word &word = targetPhrase.GetWord(phrasePos);

    // Terminal: query the LM (recording prefix fragments while growing).
    if (!word.IsNonTerminal()) {
      EvaluateTerminal(
        word, hypoScore,
        newState, state,
        prefixFragments, prefixLength
      );
    }

    // Nonterminal: merge in the LM state of the corresponding
    // subderivation.
    else {

      const ChartHypothesis *prevHypo = hypo.GetPrevHypo(nonTermIndexMap[phrasePos]);
      const DALMChartState* prevState =
        static_cast<const DALMChartState*>(prevHypo->GetFFState(featureID));

      size_t prevTargetPhraseLength = prevHypo->GetCurrTargetPhrase().GetSize();

      EvaluateNonTerminal(
        word, hypoScore,
        newState, state,
        prefixFragments, prefixLength,
        prevState, prevTargetPhraseLength
      );
    }
  }
  hypoScore = TransformLMScore(hypoScore);
  // Subtract the estimate already credited during rule scoring so the LM
  // contribution is not counted twice.
  hypoScore -= hypo.GetTranslationOption().GetScores().GetScoresForProducer(this)[0];

  // Second component (OOV count) is not tracked here and stays 0.
  if (OOVFeatureEnabled()) {
    std::vector<float> scores(2);
    scores[0] = hypoScore;
    scores[1] = 0.0;
    out->PlusEquals(this, scores);
  } else {
    out->PlusEquals(this, hypoScore);
  }

  return newState;
}
00484
00485 bool LanguageModelDALM::IsUseable(const FactorMask &mask) const
00486 {
00487 return mask[m_factorType];
00488 }
00489
00490 void LanguageModelDALM::CreateVocabMapping(const std::string &wordstxt)
00491 {
00492 InputFileStream vocabStrm(wordstxt);
00493
00494 std::vector< std::pair<std::size_t, DALM::VocabId> > vlist;
00495 string line;
00496 std::size_t max_fid = 0;
00497 while(getline(vocabStrm, line)) {
00498 const Factor *factor = FactorCollection::Instance().AddFactor(line);
00499 std::size_t fid = factor->GetId();
00500 DALM::VocabId wid = m_vocab->lookup(line.c_str());
00501
00502 vlist.push_back(std::pair<std::size_t, DALM::VocabId>(fid, wid));
00503 if(max_fid < fid) max_fid = fid;
00504 }
00505
00506 for(std::size_t i = 0; i < m_vocabMap.size(); i++) {
00507 m_vocabMap[i] = m_vocab->unk();
00508 }
00509
00510 m_vocabMap.resize(max_fid+1, m_vocab->unk());
00511 std::vector< std::pair<std::size_t, DALM::VocabId> >::iterator it = vlist.begin();
00512 while(it != vlist.end()) {
00513 std::pair<std::size_t, DALM::VocabId> &entry = *it;
00514 m_vocabMap[entry.first] = entry.second;
00515
00516 ++it;
00517 }
00518 }
00519
00520 DALM::VocabId LanguageModelDALM::GetVocabId(const Factor *factor) const
00521 {
00522 std::size_t fid = factor->GetId();
00523 return (m_vocabMap.size() > fid)? m_vocabMap[fid] : m_vocab->unk();
00524 }
00525
00526 void LanguageModelDALM::SetParameter(const std::string& key, const std::string& value)
00527 {
00528 if (key == "factor") {
00529 m_factorType = Scan<FactorType>(value);
00530 } else if (key == "order") {
00531 m_nGramOrder = Scan<size_t>(value);
00532 } else if (key == "path") {
00533 m_filePath = value;
00534 } else {
00535 LanguageModel::SetParameter(key, value);
00536 }
00537 m_ContextSize = m_nGramOrder-1;
00538 }
00539
// Score one terminal word during chart decoding.
// While the state's left prefix is still growing (!LargeEnough), each query
// also records a DALM::Fragment so a parent hypothesis can re-score the
// prefix later with more left context.
void LanguageModelDALM::EvaluateTerminal(
  const Word &word,
  float &hypoScore,
  DALMChartState *newState,
  DALM::State &state,
  DALM::Fragment *prefixFragments,
  unsigned char &prefixLength) const
{

  DALM::VocabId wid = GetVocabId(word.GetFactor(m_factorType));
  if (newState->LargeEnough()) {
    // Prefix already fixed: ordinary query.
    float score = m_lm->query(wid, state);
    hypoScore += score;
  } else {
    unsigned char prevLen = state.get_count();
    float score = m_lm->query(wid, state, prefixFragments[prefixLength]);

    if(score > 0) {
      // NOTE(review): a positive return appears to act as a DALM sentinel
      // here — the value is subtracted rather than added and the prefix is
      // closed; confirm against DALM::LM::query()'s contract.
      hypoScore -= score;
      newState->SetAsLarge();
    } else if(state.get_count()<=prefixLength) {
      // Context did not grow beyond the prefix: prefix is complete.
      hypoScore += score;
      prefixLength++;
      newState->SetAsLarge();
    } else {
      hypoScore += score;
      prefixLength++;
      // If the context was truncated (did not extend by exactly one word),
      // or the prefix reached order-1 words, it cannot grow further.
      if(state.get_count() < std::min(prevLen+1, (int)m_ContextSize)) {
        newState->SetAsLarge();
      }
      if(prefixLength >= m_ContextSize) newState->SetAsLarge();
    }
  }
}
00574
// Merge a nonterminal's subderivation into the current LM state:
// re-score the subderivation's recorded prefix fragments in the current
// context (filling the "gap"), then adopt its right context as ours.
void LanguageModelDALM::EvaluateNonTerminal(
  const Word &word,
  float &hypoScore,
  DALMChartState *newState,
  DALM::State &state,
  DALM::Fragment *prefixFragments,
  unsigned char &prefixLength,
  const DALMChartState *prevState,
  size_t prevTargetPhraseLength
) const
{

  const unsigned char prevPrefixLength = prevState->GetPrefixLength();
  const DALM::Fragment *prevPrefixFragments = prevState->GetPrefixFragments();

  // No recorded prefix (e.g. subderivation began with <s>): nothing to
  // re-score; just credit the backoff weights and adopt its right context.
  if(prevPrefixLength == 0) {
    newState->SetAsLarge();
    hypoScore += m_lm->sum_bows(state, 0, state.get_count());
    state = prevState->GetRightContext();
    return;
  }
  // No context on our side: the fragments' scores cannot improve.
  if(!state.has_context()) {
    newState->SetAsLarge();
    state = prevState->GetRightContext();
    return;
  }
  DALM::Gap gap(state);
  unsigned char prevLen = state.get_count();

  // Re-score each recorded fragment of the subderivation's prefix in the
  // current context.
  for(size_t prefixPos = 0; prefixPos < prevPrefixLength; prefixPos++) {
    const DALM::Fragment &f = prevPrefixFragments[prefixPos];

    if (newState->LargeEnough()) {
      // Our own prefix is fixed: plain gap query.
      float score = m_lm->query(f, state, gap);
      hypoScore += score;

      if(!gap.is_extended()) {
        // Context chain broke: remaining fragments keep their old scores.
        state = prevState->GetRightContext();
        return;
      } else if(state.get_count() <= prefixPos+1) {
        state = prevState->GetRightContext();
        return;
      }
    } else {
      // Our prefix is still growing: also record the fragment for our own
      // parent to re-score later.
      DALM::Fragment &fnew = prefixFragments[prefixLength];
      float score = m_lm->query(f, state, gap, fnew);
      hypoScore += score;

      if(!gap.is_extended()) {
        newState->SetAsLarge();
        state = prevState->GetRightContext();
        return;
      } else if(state.get_count() <= prefixPos+1) {
        if(state.get_count() == prefixPos+1 && !gap.is_finalized()) {
          prefixLength++;
        }
        newState->SetAsLarge();
        state = prevState->GetRightContext();
        return;
      } else if(gap.is_finalized()) {
        newState->SetAsLarge();
      } else {
        prefixLength++;
        // Same truncation checks as in EvaluateTerminal.
        if(state.get_count() < std::min(prevLen+1, (int)m_ContextSize)) {
          newState->SetAsLarge();
        }

        if(prefixLength >= m_ContextSize) newState->SetAsLarge();
      }
    }
    gap.succ();
    prevLen = state.get_count();
  }

  // All fragments consumed without an early exit.
  if (prevState->LargeEnough()) {
    newState->SetAsLarge();
    // Credit backoff weights for the context beyond the re-scored prefix.
    hypoScore += m_lm->sum_bows(state, prevPrefixLength, state.get_count());

    state = prevState->GetRightContext();
  } else {
    // Subderivation short enough that its whole span is prefix: stitch the
    // states together instead of replacing ours outright.
    m_lm->set_state(state, prevState->GetRightContext(), prevPrefixFragments, gap);
  }
}
00662
00663 }