00001 #include "GlobalLexicalModelUnlimited.h"
00002 #include <fstream>
00003 #include "moses/StaticData.h"
00004 #include "moses/InputFileStream.h"
00005 #include "moses/Hypothesis.h"
00006 #include "moses/TranslationTask.h"
00007 #include "util/string_piece_hash.hh"
00008 #include "util/string_stream.hh"
00009
00010 using namespace std;
00011
00012 namespace Moses
00013 {
00014 GlobalLexicalModelUnlimited::GlobalLexicalModelUnlimited(const std::string &line)
00015 :StatelessFeatureFunction(0, line)
00016 {
00017 UTIL_THROW(util::Exception,
00018 "GlobalLexicalModelUnlimited hasn't been refactored for new feature function framework yet");
00019
00020 const vector<string> modelSpec = Tokenize(line);
00021
00022 for (size_t i = 0; i < modelSpec.size(); i++ ) {
00023 bool ignorePunctuation = true, biasFeature = false, restricted = false;
00024 size_t context = 0;
00025 string filenameSource, filenameTarget;
00026 vector< string > factors;
00027 vector< string > spec = Tokenize(modelSpec[i]," ");
00028
00029
00030 if (spec.size() > 0) {
00031 if (spec.size() != 2 && spec.size() != 3 && spec.size() != 4 && spec.size() != 6) {
00032 std::cerr << "Format of glm feature is <factor-src>-<factor-tgt> [ignore-punct] [use-bias] "
00033 << "[context-type] [filename-src filename-tgt]";
00034
00035 }
00036
00037 factors = Tokenize(spec[0],"-");
00038 if (spec.size() >= 2)
00039 ignorePunctuation = Scan<size_t>(spec[1]);
00040 if (spec.size() >= 3)
00041 biasFeature = Scan<size_t>(spec[2]);
00042 if (spec.size() >= 4)
00043 context = Scan<size_t>(spec[3]);
00044 if (spec.size() == 6) {
00045 filenameSource = spec[4];
00046 filenameTarget = spec[5];
00047 restricted = true;
00048 }
00049 } else
00050 factors = Tokenize(modelSpec[i],"-");
00051
00052 if ( factors.size() != 2 ) {
00053 std::cerr << "Wrong factor definition for global lexical model unlimited: " << modelSpec[i];
00054
00055 }
00056
00057 const vector<FactorType> inputFactors = Tokenize<FactorType>(factors[0],",");
00058 const vector<FactorType> outputFactors = Tokenize<FactorType>(factors[1],",");
00059 throw runtime_error("GlobalLexicalModelUnlimited should be reimplemented as a stateful feature");
00060 GlobalLexicalModelUnlimited* glmu = NULL;
00061
00062 if (restricted) {
00063 cerr << "loading word translation word lists from " << filenameSource << " and " << filenameTarget << endl;
00064 if (!glmu->Load(filenameSource, filenameTarget)) {
00065 std::cerr << "Unable to load word lists for word translation feature from files "
00066 << filenameSource
00067 << " and "
00068 << filenameTarget;
00069
00070 }
00071 }
00072 }
00073 }
00074
00075 bool GlobalLexicalModelUnlimited::Load(const std::string &filePathSource,
00076 const std::string &filePathTarget)
00077 {
00078
00079 ifstream inFileSource(filePathSource.c_str());
00080 if (!inFileSource) {
00081 cerr << "could not open file " << filePathSource << endl;
00082 return false;
00083 }
00084
00085 std::string line;
00086 while (getline(inFileSource, line)) {
00087 m_vocabSource.insert(line);
00088 }
00089
00090 inFileSource.close();
00091
00092
00093 ifstream inFileTarget(filePathTarget.c_str());
00094 if (!inFileTarget) {
00095 cerr << "could not open file " << filePathTarget << endl;
00096 return false;
00097 }
00098
00099 while (getline(inFileTarget, line)) {
00100 m_vocabTarget.insert(line);
00101 }
00102
00103 inFileTarget.close();
00104
00105 m_unrestricted = false;
00106 return true;
00107 }
00108
00109 void GlobalLexicalModelUnlimited::InitializeForInput(ttasksptr const& ttask)
00110 {
00111 UTIL_THROW_IF2(ttask->GetSource()->GetType() != SentenceInput,
00112 "GlobalLexicalModel works only with sentence input.");
00113 Sentence const* s = reinterpret_cast<Sentence const*>(ttask->GetSource().get());
00114 m_local.reset(new ThreadLocalStorage);
00115 m_local->input = s;
00116 }
00117
00118 void GlobalLexicalModelUnlimited::EvaluateWhenApplied(const Hypothesis& cur_hypo, ScoreComponentCollection* accumulator) const
00119 {
00120 const Sentence& input = *(m_local->input);
00121 const TargetPhrase& targetPhrase = cur_hypo.GetCurrTargetPhrase();
00122
00123 for(size_t targetIndex = 0; targetIndex < targetPhrase.GetSize(); targetIndex++ ) {
00124 StringPiece targetString = targetPhrase.GetWord(targetIndex).GetString(0);
00125
00126 if (m_ignorePunctuation) {
00127
00128 char firstChar = targetString[0];
00129 CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar );
00130 if(charIterator != m_punctuationHash.end())
00131 continue;
00132 }
00133
00134 if (m_biasFeature) {
00135 util::StringStream feature;
00136 feature << "glm_";
00137 feature << targetString;
00138 feature << "~";
00139 feature << "**BIAS**";
00140 accumulator->SparsePlusEquals(feature.str(), 1);
00141 }
00142
00143 boost::unordered_set<uint64_t> alreadyScored;
00144 for(size_t sourceIndex = 0; sourceIndex < input.GetSize(); sourceIndex++ ) {
00145 const StringPiece sourceString = input.GetWord(sourceIndex).GetString(0);
00146
00147
00148 if (m_ignorePunctuation) {
00149
00150 char firstChar = sourceString[0];
00151 CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar );
00152 if(charIterator != m_punctuationHash.end())
00153 continue;
00154 }
00155 const uint64_t sourceHash = util::MurmurHashNative(sourceString.data(), sourceString.size());
00156
00157 if ( alreadyScored.find(sourceHash) == alreadyScored.end()) {
00158 bool sourceExists, targetExists;
00159 if (!m_unrestricted) {
00160 sourceExists = FindStringPiece(m_vocabSource, sourceString ) != m_vocabSource.end();
00161 targetExists = FindStringPiece(m_vocabTarget, targetString) != m_vocabTarget.end();
00162 }
00163
00164
00165 if (m_unrestricted || (sourceExists && targetExists)) {
00166 if (m_sourceContext) {
00167 if (sourceIndex == 0) {
00168
00169 util::StringStream feature;
00170 feature << "glm_";
00171 feature << targetString;
00172 feature << "~";
00173 feature << "<s>,";
00174 feature << sourceString;
00175 accumulator->SparsePlusEquals(feature.str(), 1);
00176 alreadyScored.insert(sourceHash);
00177 }
00178
00179
00180 for(int contextIndex = sourceIndex+1; contextIndex < input.GetSize(); contextIndex++ ) {
00181 StringPiece contextString = input.GetWord(contextIndex).GetString(0);
00182 bool contextExists;
00183 if (!m_unrestricted)
00184 contextExists = FindStringPiece(m_vocabSource, contextString ) != m_vocabSource.end();
00185
00186 if (m_unrestricted || contextExists) {
00187 util::StringStream feature;
00188 feature << "glm_";
00189 feature << targetString;
00190 feature << "~";
00191 feature << sourceString;
00192 feature << ",";
00193 feature << contextString;
00194 accumulator->SparsePlusEquals(feature.str(), 1);
00195 alreadyScored.insert(sourceHash);
00196 }
00197 }
00198 } else if (m_biphrase) {
00199
00200 int globalTargetIndex = cur_hypo.GetSize() - targetPhrase.GetSize() + targetIndex;
00201
00202
00203 StringPiece targetContext;
00204 if (globalTargetIndex > 0)
00205 targetContext = cur_hypo.GetWord(globalTargetIndex-1).GetString(0);
00206 else
00207 targetContext = "<s>";
00208
00209 if (sourceIndex == 0) {
00210 StringPiece sourceTrigger = "<s>";
00211 AddFeature(accumulator, sourceTrigger, sourceString,
00212 targetContext, targetString);
00213 } else
00214 for(int contextIndex = sourceIndex-1; contextIndex >= 0; contextIndex-- ) {
00215 StringPiece sourceTrigger = input.GetWord(contextIndex).GetString(0);
00216 bool sourceTriggerExists = false;
00217 if (!m_unrestricted)
00218 sourceTriggerExists = FindStringPiece(m_vocabSource, sourceTrigger ) != m_vocabSource.end();
00219
00220 if (m_unrestricted || sourceTriggerExists)
00221 AddFeature(accumulator, sourceTrigger, sourceString,
00222 targetContext, targetString);
00223 }
00224
00225
00226 StringPiece sourceContext;
00227 if (sourceIndex-1 >= 0)
00228 sourceContext = input.GetWord(sourceIndex-1).GetString(0);
00229 else
00230 sourceContext = "<s>";
00231
00232 if (globalTargetIndex == 0) {
00233 string targetTrigger = "<s>";
00234 AddFeature(accumulator, sourceContext, sourceString,
00235 targetTrigger, targetString);
00236 } else
00237 for(int globalContextIndex = globalTargetIndex-1; globalContextIndex >= 0; globalContextIndex-- ) {
00238 StringPiece targetTrigger = cur_hypo.GetWord(globalContextIndex).GetString(0);
00239 bool targetTriggerExists = false;
00240 if (!m_unrestricted)
00241 targetTriggerExists = FindStringPiece(m_vocabTarget, targetTrigger ) != m_vocabTarget.end();
00242
00243 if (m_unrestricted || targetTriggerExists)
00244 AddFeature(accumulator, sourceContext, sourceString,
00245 targetTrigger, targetString);
00246 }
00247 } else if (m_bitrigger) {
00248
00249 int globalTargetIndex = cur_hypo.GetSize() - targetPhrase.GetSize() + targetIndex;
00250
00251 if (sourceIndex == 0) {
00252 StringPiece sourceTrigger = "<s>";
00253 bool sourceTriggerExists = true;
00254
00255 if (globalTargetIndex == 0) {
00256 string targetTrigger = "<s>";
00257 bool targetTriggerExists = true;
00258
00259 if (m_unrestricted || (sourceTriggerExists && targetTriggerExists))
00260 AddFeature(accumulator, sourceTrigger, sourceString,
00261 targetTrigger, targetString);
00262 } else {
00263
00264 for(int globalContextIndex = globalTargetIndex-1; globalContextIndex >= 0; globalContextIndex-- ) {
00265 StringPiece targetTrigger = cur_hypo.GetWord(globalContextIndex).GetString(0);
00266 bool targetTriggerExists = false;
00267 if (!m_unrestricted)
00268 targetTriggerExists = FindStringPiece(m_vocabTarget, targetTrigger ) != m_vocabTarget.end();
00269
00270 if (m_unrestricted || (sourceTriggerExists && targetTriggerExists))
00271 AddFeature(accumulator, sourceTrigger, sourceString,
00272 targetTrigger, targetString);
00273 }
00274 }
00275 }
00276
00277 else {
00278
00279 for(int contextIndex = sourceIndex-1; contextIndex >= 0; contextIndex-- ) {
00280 StringPiece sourceTrigger = input.GetWord(contextIndex).GetString(0);
00281 bool sourceTriggerExists = false;
00282 if (!m_unrestricted)
00283 sourceTriggerExists = FindStringPiece(m_vocabSource, sourceTrigger ) != m_vocabSource.end();
00284
00285 if (globalTargetIndex == 0) {
00286 string targetTrigger = "<s>";
00287 bool targetTriggerExists = true;
00288
00289 if (m_unrestricted || (sourceTriggerExists && targetTriggerExists))
00290 AddFeature(accumulator, sourceTrigger, sourceString,
00291 targetTrigger, targetString);
00292 } else {
00293
00294 for(int globalContextIndex = globalTargetIndex-1; globalContextIndex >= 0; globalContextIndex-- ) {
00295 StringPiece targetTrigger = cur_hypo.GetWord(globalContextIndex).GetString(0);
00296 bool targetTriggerExists = false;
00297 if (!m_unrestricted)
00298 targetTriggerExists = FindStringPiece(m_vocabTarget, targetTrigger ) != m_vocabTarget.end();
00299
00300 if (m_unrestricted || (sourceTriggerExists && targetTriggerExists))
00301 AddFeature(accumulator, sourceTrigger, sourceString,
00302 targetTrigger, targetString);
00303 }
00304 }
00305 }
00306 }
00307 } else {
00308 util::StringStream feature;
00309 feature << "glm_";
00310 feature << targetString;
00311 feature << "~";
00312 feature << sourceString;
00313 accumulator->SparsePlusEquals(feature.str(), 1);
00314 alreadyScored.insert(sourceHash);
00315
00316 }
00317 }
00318 }
00319 }
00320 }
00321 }
00322
00323 void GlobalLexicalModelUnlimited::AddFeature(ScoreComponentCollection* accumulator,
00324 StringPiece sourceTrigger, StringPiece sourceWord,
00325 StringPiece targetTrigger, StringPiece targetWord) const
00326 {
00327 util::StringStream feature;
00328 feature << "glm_";
00329 feature << targetTrigger;
00330 feature << ",";
00331 feature << targetWord;
00332 feature << "~";
00333 feature << sourceTrigger;
00334 feature << ",";
00335 feature << sourceWord;
00336 accumulator->SparsePlusEquals(feature.str(), 1);
00337
00338 }
00339
00340 }