#include <cstring>
#include <iostream>
#include <memory>
#include <cstdlib>
#include <boost/shared_ptr.hpp>
#include <boost/lexical_cast.hpp>

#include "lm/binary_format.hh"
#include "lm/enumerate_vocab.hh"
#include "lm/left.hh"
#include "lm/model.hh"
#include "util/exception.hh"
#include "util/tokenize_piece.hh"
#include "util/string_stream.hh"

#include "Ken.h"
#include "Base.h"
#include "moses/FF/FFState.h"
#include "moses/TypeDef.h"
#include "moses/Util.h"
#include "moses/FactorCollection.h"
#include "moses/Phrase.h"
#include "moses/InputFileStream.h"
#include "moses/StaticData.h"
#include "moses/ChartHypothesis.h"
#include "moses/Incremental.h"
#include "moses/Syntax/SHyperedge.h"
#include "moses/Syntax/SVertex.h"

using namespace std;

namespace Moses
{
namespace
{

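// FFState wrapper around a KenLM n-gram state, used by the phrase-based
// decoder to recombine hypotheses that share the same LM context.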
struct KenLMState : public FFState {
  lm::ngram::State state;
  virtual size_t hash() const {
    size_t ret = hash_value(state);
    return ret;
  }
  virtual bool operator==(const FFState& o) const {
    const KenLMState &other = static_cast<const KenLMState &>(o);
    bool ret = state == other.state;
    return ret;
  }
};

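// Vocabulary enumeration callback invoked by KenLM while the model loads;
// for every Moses factor id it records the corresponding KenLM word index.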
class MappingBuilder : public lm::EnumerateVocab
{
public:
  MappingBuilder(FactorCollection &factorCollection, std::vector<lm::WordIndex> &mapping)
    : m_factorCollection(factorCollection), m_mapping(mapping) {}

  void Add(lm::WordIndex index, const StringPiece &str) {
    std::size_t factorId = m_factorCollection.AddFactor(str)->GetId();
    if (m_mapping.size() <= factorId) {
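      // resize() fills the new entries with 0, which KenLM reserves for <unk>.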
      m_mapping.resize(factorId + 1);
    }
    m_mapping[factorId] = index;
  }

private:
  FactorCollection &m_factorCollection;
  std::vector<lm::WordIndex> &m_mapping;
};

}

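// (Re)load the KenLM model from file, rebuilding the factor-id-to-word-index
// mapping as the vocabulary is enumerated.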
template <class Model> void LanguageModelKen<Model>::LoadModel(const std::string &file, util::LoadMethod load_method)
{
  m_lmIdLookup.clear();

  lm::ngram::Config config;
  if (this->m_verbosity >= 1) {
    config.messages = &std::cerr;
  } else {
    config.messages = NULL;
  }
  FactorCollection &collection = FactorCollection::Instance();
  MappingBuilder builder(collection, m_lmIdLookup);
  config.enumerate_vocab = &builder;
  config.load_method = load_method;

  m_ngram.reset(new Model(file.c_str(), config));
  VERBOSE(2, "LanguageModelKen " << m_description << " reset to " << file << "\n");
}

template <class Model> LanguageModelKen<Model>::LanguageModelKen(const std::string &line, const std::string &file, FactorType factorType, util::LoadMethod load_method)
  :LanguageModel(line)
  ,m_beginSentenceFactor(FactorCollection::Instance().AddFactor(BOS_))
  ,m_factorType(factorType)
{
  ReadParameters();
  LoadModel(file, load_method);
}

template <class Model> LanguageModelKen<Model>::LanguageModelKen()
  :LanguageModel("KENLM")
  ,m_beginSentenceFactor(FactorCollection::Instance().AddFactor(BOS_))
  ,m_factorType(0)
{
  ReadParameters();
}

template <class Model> LanguageModelKen<Model>::LanguageModelKen(const LanguageModelKen<Model> &copy_from)
  :LanguageModel(copy_from.GetArgLine()),
   m_ngram(copy_from.m_ngram),
   m_beginSentenceFactor(copy_from.m_beginSentenceFactor),
   m_factorType(copy_from.m_factorType),
   m_lmIdLookup(copy_from.m_lmIdLookup)
{
}

template <class Model> const FFState * LanguageModelKen<Model>::EmptyHypothesisState(const InputType &) const
{
  KenLMState *ret = new KenLMState();
  ret->state = m_ngram->BeginSentenceState();
  return ret;
}

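// Estimate the LM score of a phrase in isolation. fullScore includes words
// scored with incomplete left context; ngramScore covers only the words whose
// full n-gram context lies inside the phrase; oovCount counts unknown words.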
template <class Model> void LanguageModelKen<Model>::CalcScore(const Phrase &phrase, float &fullScore, float &ngramScore, size_t &oovCount) const
{
  fullScore = 0;
  ngramScore = 0;
  oovCount = 0;

  if (!phrase.GetSize()) return;

  lm::ngram::ChartState discarded_sadly;
  lm::ngram::RuleScore<Model> scorer(*m_ngram, discarded_sadly);

  size_t position;
  if (m_beginSentenceFactor == phrase.GetWord(0).GetFactor(m_factorType)) {
    scorer.BeginSentence();
    position = 1;
  } else {
    position = 0;
  }

  size_t ngramBoundary = m_ngram->Order() - 1;

  size_t end_loop = std::min(ngramBoundary, phrase.GetSize());
  for (; position < end_loop; ++position) {
    const Word &word = phrase.GetWord(position);
    if (word.IsNonTerminal()) {
      fullScore += scorer.Finish();
      scorer.Reset();
    } else {
      lm::WordIndex index = TranslateID(word);
      scorer.Terminal(index);
      if (!index) ++oovCount;
    }
  }
  float before_boundary = fullScore + scorer.Finish();
  for (; position < phrase.GetSize(); ++position) {
    const Word &word = phrase.GetWord(position);
    if (word.IsNonTerminal()) {
      fullScore += scorer.Finish();
      scorer.Reset();
    } else {
      lm::WordIndex index = TranslateID(word);
      scorer.Terminal(index);
      if (!index) ++oovCount;
    }
  }
  fullScore += scorer.Finish();

  ngramScore = TransformLMScore(fullScore - before_boundary);
  fullScore = TransformLMScore(fullScore);
}

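// Score the target words added by a phrase-based hypothesis, given the LM
// state of the previous hypothesis, and return the new LM state.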
template <class Model> FFState *LanguageModelKen<Model>::EvaluateWhenApplied(const Hypothesis &hypo, const FFState *ps, ScoreComponentCollection *out) const
{
  const lm::ngram::State &in_state = static_cast<const KenLMState&>(*ps).state;

  std::auto_ptr<KenLMState> ret(new KenLMState());

  if (!hypo.GetCurrTargetLength()) {
    ret->state = in_state;
    return ret.release();
  }

  const std::size_t begin = hypo.GetCurrTargetWordsRange().GetStartPos();
  const std::size_t end = hypo.GetCurrTargetWordsRange().GetEndPos() + 1;
  const std::size_t adjust_end = std::min(end, begin + m_ngram->Order() - 1);

  std::size_t position = begin;
  typename Model::State aux_state;
  typename Model::State *state0 = &ret->state, *state1 = &aux_state;

  float score = m_ngram->Score(in_state, TranslateID(hypo.GetWord(position)), *state0);
  ++position;
  for (; position < adjust_end; ++position) {
    score += m_ngram->Score(*state0, TranslateID(hypo.GetWord(position)), *state1);
    std::swap(state0, state1);
  }

  if (hypo.IsSourceCompleted()) {
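    // The source sentence is fully covered: also score </s>, using the last
    // order-1 words of the hypothesis as context.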
    std::vector<lm::WordIndex> indices(m_ngram->Order() - 1);
    const lm::WordIndex *last = LastIDs(hypo, &indices.front());
    score += m_ngram->FullScoreForgotState(&indices.front(), last, m_ngram->GetVocabulary().EndSentence(), ret->state).prob;
  } else if (adjust_end < end) {
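    // The phrase is longer than order-1 words; the fully-internal n-grams were
    // already scored when the phrase was scored in isolation, so only the LM
    // state needs recomputing, and it depends on the last order-1 words alone.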
    std::vector<lm::WordIndex> indices(m_ngram->Order() - 1);
    const lm::WordIndex *last = LastIDs(hypo, &indices.front());
    m_ngram->GetState(&indices.front(), last, ret->state);
  } else if (state0 != &ret->state) {
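    // After the swaps the most recent state ended up in aux_state, so copy it
    // into the state object being returned.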
    ret->state = *state0;
  }

  score = TransformLMScore(score);

  if (OOVFeatureEnabled()) {
    std::vector<float> scores(2);
    scores[0] = score;
    scores[1] = 0.0;
    out->PlusEquals(this, scores);
  } else {
    out->PlusEquals(this, score);
  }

  return ret.release();
}

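// FFState wrapper around a KenLM ChartState, used by the chart-based and
// syntax decoders for hypothesis recombination.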
class LanguageModelChartStateKenLM : public FFState
{
public:
  LanguageModelChartStateKenLM() {}

  const lm::ngram::ChartState &GetChartState() const {
    return m_state;
  }
  lm::ngram::ChartState &GetChartState() {
    return m_state;
  }

  size_t hash() const {
    size_t ret = hash_value(m_state);
    return ret;
  }
  virtual bool operator==(const FFState& o) const {
    const LanguageModelChartStateKenLM &other = static_cast<const LanguageModelChartStateKenLM &>(o);
    bool ret = m_state == other.m_state;
    return ret;
  }

private:
  lm::ngram::ChartState m_state;
};

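// Score the target side of a chart hypothesis's rule, plugging in the chart
// states of its non-terminal antecedents, then subtract the estimate that was
// already added when the target phrase was scored in isolation.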
template <class Model> FFState *LanguageModelKen<Model>::EvaluateWhenApplied(const ChartHypothesis& hypo, int featureID, ScoreComponentCollection *accumulator) const
{
  LanguageModelChartStateKenLM *newState = new LanguageModelChartStateKenLM();
  lm::ngram::RuleScore<Model> ruleScore(*m_ngram, newState->GetChartState());
  const TargetPhrase &target = hypo.GetCurrTargetPhrase();
  const AlignmentInfo::NonTermIndexMap &nonTermIndexMap =
    target.GetAlignNonTerm().GetNonTermIndexMap();

  const size_t size = hypo.GetCurrTargetPhrase().GetSize();
  size_t phrasePos = 0;

  if (size) {
    const Word &word = hypo.GetCurrTargetPhrase().GetWord(0);
    if (word.GetFactor(m_factorType) == m_beginSentenceFactor) {
      ruleScore.BeginSentence();
      phrasePos++;
    } else if (word.IsNonTerminal()) {
      const ChartHypothesis *prevHypo = hypo.GetPrevHypo(nonTermIndexMap[phrasePos]);
      const lm::ngram::ChartState &prevState = static_cast<const LanguageModelChartStateKenLM*>(prevHypo->GetFFState(featureID))->GetChartState();
      ruleScore.BeginNonTerminal(prevState);
      phrasePos++;
    }
  }

  for (; phrasePos < size; phrasePos++) {
    const Word &word = hypo.GetCurrTargetPhrase().GetWord(phrasePos);
    if (word.IsNonTerminal()) {
      const ChartHypothesis *prevHypo = hypo.GetPrevHypo(nonTermIndexMap[phrasePos]);
      const lm::ngram::ChartState &prevState = static_cast<const LanguageModelChartStateKenLM*>(prevHypo->GetFFState(featureID))->GetChartState();
      ruleScore.NonTerminal(prevState);
    } else {
      ruleScore.Terminal(TranslateID(word));
    }
  }

  float score = ruleScore.Finish();
  score = TransformLMScore(score);
  score -= hypo.GetTranslationOption().GetScores().GetScoresForProducer(this)[0];

  if (OOVFeatureEnabled()) {
    std::vector<float> scores(2);
    scores[0] = score;
    scores[1] = 0.0;
    accumulator->PlusEquals(this, scores);
  } else {
    accumulator->PlusEquals(this, score);
  }
  return newState;
}

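// Syntax-decoder (SHyperedge) variant of the above: antecedent chart states
// come from the hyperedge's tail vertices instead of previous chart hypotheses.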
template <class Model> FFState *LanguageModelKen<Model>::EvaluateWhenApplied(const Syntax::SHyperedge& hyperedge, int featureID, ScoreComponentCollection *accumulator) const
{
  LanguageModelChartStateKenLM *newState = new LanguageModelChartStateKenLM();
  lm::ngram::RuleScore<Model> ruleScore(*m_ngram, newState->GetChartState());
  const TargetPhrase &target = *hyperedge.label.translation;
  const AlignmentInfo::NonTermIndexMap &nonTermIndexMap =
    target.GetAlignNonTerm().GetNonTermIndexMap2();

  const size_t size = target.GetSize();
  size_t phrasePos = 0;

  if (size) {
    const Word &word = target.GetWord(0);
    if (word.GetFactor(m_factorType) == m_beginSentenceFactor) {
      ruleScore.BeginSentence();
      phrasePos++;
    } else if (word.IsNonTerminal()) {
      const Syntax::SVertex *pred = hyperedge.tail[nonTermIndexMap[phrasePos]];
      const lm::ngram::ChartState &prevState = static_cast<const LanguageModelChartStateKenLM*>(pred->states[featureID])->GetChartState();
      ruleScore.BeginNonTerminal(prevState);
      phrasePos++;
    }
  }

  for (; phrasePos < size; phrasePos++) {
    const Word &word = target.GetWord(phrasePos);
    if (word.IsNonTerminal()) {
      const Syntax::SVertex *pred = hyperedge.tail[nonTermIndexMap[phrasePos]];
      const lm::ngram::ChartState &prevState = static_cast<const LanguageModelChartStateKenLM*>(pred->states[featureID])->GetChartState();
      ruleScore.NonTerminal(prevState);
    } else {
      ruleScore.Terminal(TranslateID(word));
    }
  }

  float score = ruleScore.Finish();
  score = TransformLMScore(score);
  score -= target.GetScoreBreakdown().GetScoresForProducer(this)[0];

  if (OOVFeatureEnabled()) {
    std::vector<float> scores(2);
    scores[0] = score;
    scores[1] = 0.0;
    accumulator->PlusEquals(this, scores);
  } else {
    accumulator->PlusEquals(this, score);
  }
  return newState;
}

template <class Model> void LanguageModelKen<Model>::IncrementalCallback(Incremental::Manager &manager) const
{
  manager.LMCallback(*m_ngram, m_lmIdLookup);
}

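// Append per-word LM details to the output stream: for each word, the n-gram
// order actually matched and the transformed log probability, marking unknowns.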
template <class Model> void LanguageModelKen<Model>::ReportHistoryOrder(std::ostream &out, const Phrase &phrase) const
{
  out << "|lm=(";
  if (!phrase.GetSize()) return;

  typename Model::State aux_state;
  typename Model::State start_of_sentence_state = m_ngram->BeginSentenceState();
  typename Model::State *state0 = &start_of_sentence_state;
  typename Model::State *state1 = &aux_state;

  for (std::size_t position = 0; position < phrase.GetSize(); position++) {
    const lm::WordIndex idx = TranslateID(phrase.GetWord(position));
    lm::FullScoreReturn ret(m_ngram->FullScore(*state0, idx, *state1));
    if (position) out << ",";
    out << (int) ret.ngram_length << ":" << TransformLMScore(ret.prob);
    if (idx == 0) out << ":unk";
    std::swap(state0, state1);
  }
  out << ")| ";
}

template <class Model>
bool LanguageModelKen<Model>::IsUseable(const FactorMask &mask) const
{
  bool ret = mask[m_factorType];
  return ret;
}

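// Explicit instantiations: emit the member functions of LanguageModelKen for
// every KenLM model type in this translation unit so they are available at
// link time.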
template class LanguageModelKen<lm::ngram::ProbingModel>;
template class LanguageModelKen<lm::ngram::RestProbingModel>;
template class LanguageModelKen<lm::ngram::TrieModel>;
template class LanguageModelKen<lm::ngram::ArrayTrieModel>;
template class LanguageModelKen<lm::ngram::QuantTrieModel>;
template class LanguageModelKen<lm::ngram::QuantArrayTrieModel>;

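// Parse a "KENLM ..." feature line: handle the factor, path, and load-method
// arguments here and forward everything else on the rebuilt line.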
LanguageModel *ConstructKenLM(const std::string &lineOrig)
{
  FactorType factorType = 0;
  string filePath;
  util::LoadMethod load_method = util::POPULATE_OR_READ;

  util::TokenIter<util::SingleCharacter, true> argument(lineOrig, ' ');
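  // Skip the feature name ("KENLM"), the first space-separated token.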
  ++argument;

  util::StringStream line;
  line << "KENLM";

  for (; argument; ++argument) {
    const char *equals = std::find(argument->data(), argument->data() + argument->size(), '=');
    UTIL_THROW_IF2(equals == argument->data() + argument->size(),
                   "Expected = in KenLM argument " << *argument);
    StringPiece name(argument->data(), equals - argument->data());
    StringPiece value(equals + 1, argument->data() + argument->size() - equals - 1);
    if (name == "factor") {
      factorType = boost::lexical_cast<FactorType>(value);
    } else if (name == "order") {
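      // Ignored: the n-gram order is taken from the model file itself.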
    } else if (name == "path") {
      filePath.assign(value.data(), value.size());
    } else if (name == "lazyken") {
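      // Deprecated alias for the load argument, kept for backward compatibility.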
      if (value == "0" || value == "false") {
        load_method = util::POPULATE_OR_READ;
      } else if (value == "1" || value == "true") {
        load_method = util::LAZY;
      } else {
        UTIL_THROW2("Can't parse lazyken argument " << value << ". Also, lazyken is deprecated. Use load with one of the arguments lazy, populate_or_lazy, populate_or_read, read, or parallel_read.");
      }
    } else if (name == "load") {
      if (value == "lazy") {
        load_method = util::LAZY;
      } else if (value == "populate_or_lazy") {
        load_method = util::POPULATE_OR_LAZY;
      } else if (value == "populate_or_read" || value == "populate") {
        load_method = util::POPULATE_OR_READ;
      } else if (value == "read") {
        load_method = util::READ;
      } else if (value == "parallel_read") {
        load_method = util::PARALLEL_READ;
      } else {
        UTIL_THROW2("Unknown KenLM load method " << value);
      }
    } else {
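      // Unrecognized arguments are passed through to the LanguageModel base
      // class via the rebuilt feature line.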
      line << " " << name << "=" << value;
    }
  }

  return ConstructKenLM(line.str(), filePath, factorType, load_method);
}

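// Inspect the model file to determine its binary format (if any) and construct
// a LanguageModelKen instantiated with the matching KenLM model type;
// non-binary (ARPA) files fall through to the default ProbingModel.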
LanguageModel *ConstructKenLM(const std::string &line, const std::string &file, FactorType factorType, util::LoadMethod load_method)
{
  lm::ngram::ModelType model_type;
  if (lm::ngram::RecognizeBinary(file.c_str(), model_type)) {
    switch (model_type) {
    case lm::ngram::PROBING:
      return new LanguageModelKen<lm::ngram::ProbingModel>(line, file, factorType, load_method);
    case lm::ngram::REST_PROBING:
      return new LanguageModelKen<lm::ngram::RestProbingModel>(line, file, factorType, load_method);
    case lm::ngram::TRIE:
      return new LanguageModelKen<lm::ngram::TrieModel>(line, file, factorType, load_method);
    case lm::ngram::QUANT_TRIE:
      return new LanguageModelKen<lm::ngram::QuantTrieModel>(line, file, factorType, load_method);
    case lm::ngram::ARRAY_TRIE:
      return new LanguageModelKen<lm::ngram::ArrayTrieModel>(line, file, factorType, load_method);
    case lm::ngram::QUANT_ARRAY_TRIE:
      return new LanguageModelKen<lm::ngram::QuantArrayTrieModel>(line, file, factorType, load_method);
    default:
      UTIL_THROW2("Unrecognized kenlm model type " << model_type);
    }
  } else {
    return new LanguageModelKen<lm::ngram::ProbingModel>(line, file, factorType, load_method);
  }
}

}