#include "moses/LM/oxlm/SourceOxLM.h"

#include <algorithm>
#include <fstream>
#include <string>

#include <boost/archive/binary_iarchive.hpp>
#include <boost/archive/binary_oarchive.hpp>
#include <boost/filesystem.hpp>

#include "moses/TypeDef.h"
#include "moses/TranslationTask.h"
00008
00009 using namespace std;
00010 using namespace oxlm;
00011
00012 namespace Moses
00013 {
00014
00015 SourceOxLM::SourceOxLM(const string &line)
00016 : BilingualLM(line), posBackOff(false), posFactorType(1),
00017 persistentCache(false), cacheHits(0), totalHits(0)
00018 {
00019 FactorCollection& factorFactory = FactorCollection::Instance();
00020 const Factor* NULL_factor = factorFactory.AddFactor("<unk>");
00021 NULL_word.SetFactor(0, NULL_factor);
00022 }
00023
00024 SourceOxLM::~SourceOxLM()
00025 {
00026 if (persistentCache) {
00027 double cache_hit_ratio = 100.0 * cacheHits / totalHits;
00028 cerr << "Cache hit ratio: " << cache_hit_ratio << endl;
00029 }
00030 }
00031
00032 float SourceOxLM::Score(
00033 vector<int>& source_words,
00034 vector<int>& target_words) const
00035 {
00036
00037
00038
00039
00040 vector<int> context = target_words;
00041 int word = context.back();
00042 context.pop_back();
00043 reverse(context.begin(), context.end());
00044 context.insert(context.end(), source_words.begin(), source_words.end());
00045
00046 float score;
00047 if (persistentCache) {
00048 if (!cache.get()) {
00049 cache.reset(new QueryCache());
00050 }
00051
00052 ++totalHits;
00053 NGram query(word, context);
00054 pair<double, bool> ret = cache->get(query);
00055 if (ret.second) {
00056 score = ret.first;
00057 ++cacheHits;
00058 } else {
00059 score = model.getLogProb(word, context);
00060 cache->put(query, score);
00061 }
00062 } else {
00063 score = model.getLogProb(word, context);
00064 }
00065
00066
00067 return score;
00068 }
00069
00070 int SourceOxLM::getNeuralLMId(const Word& word, bool is_source_word) const
00071 {
00072 return is_source_word ? mapper->convertSource(word) : mapper->convert(word);
00073 }
00074
00075 const Word& SourceOxLM::getNullWord() const
00076 {
00077 return NULL_word;
00078 }
00079
00080 void SourceOxLM::loadModel()
00081 {
00082 model.load(m_filePath);
00083
00084 boost::shared_ptr<ModelData> config = model.getConfig();
00085 source_ngrams = 2 * config->source_order - 1;
00086 target_ngrams = config->ngram_order - 1;
00087
00088 boost::shared_ptr<Vocabulary> vocab = model.getVocab();
00089 mapper = boost::make_shared<OxLMParallelMapper>(
00090 vocab, posBackOff, posFactorType);
00091 }
00092
00093 void SourceOxLM::SetParameter(const string& key, const string& value)
00094 {
00095 if (key == "persistent-cache") {
00096 persistentCache = Scan<bool>(value);
00097 } else if (key == "pos-back-off") {
00098 posBackOff = Scan<bool>(value);
00099 } else if (key == "pos-factor-type") {
00100 posFactorType = Scan<FactorType>(value);
00101 } else {
00102 BilingualLM::SetParameter(key, value);
00103 }
00104 }
00105
00106 void SourceOxLM::InitializeForInput(ttasksptr const& ttask)
00107 {
00108 const InputType& source = *ttask->GetSource();
00109 BilingualLM::InitializeForInput(ttask);
00110
00111 if (persistentCache) {
00112 if (!cache.get()) {
00113 cache.reset(new QueryCache());
00114 }
00115
00116 int sentence_id = source.GetTranslationId();
00117 string cacheFile = m_filePath + "." + to_string(sentence_id) + ".cache.bin";
00118 if (boost::filesystem::exists(cacheFile)) {
00119 ifstream fin(cacheFile);
00120 boost::archive::binary_iarchive iar(fin);
00121 cerr << "Loading n-gram probability cache from " << cacheFile << endl;
00122 iar >> *cache;
00123 cerr << "Done loading " << cache->size()
00124 << " n-gram probabilities..." << endl;
00125 } else {
00126 cerr << "Cache file not found!" << endl;
00127 }
00128 }
00129 }
00130
00131 void SourceOxLM::CleanUpAfterSentenceProcessing(const InputType& source)
00132 {
00133
00134 model.clearCache();
00135
00136 if (persistentCache) {
00137 int sentence_id = source.GetTranslationId();
00138 string cacheFile = m_filePath + "." + to_string(sentence_id) + ".cache.bin";
00139 ofstream fout(cacheFile);
00140 boost::archive::binary_oarchive oar(fout);
00141 cerr << "Saving persistent cache to " << cacheFile << endl;
00142 oar << *cache;
00143 cerr << "Done saving " << cache->size()
00144 << " n-gram probabilities..." << endl;
00145
00146 cache->clear();
00147 }
00148 }
00149
00150 }