00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020 #include "LoaderCompact.h"
00021
00022 #include "moses/AlignmentInfoCollection.h"
00023 #include "moses/InputFileStream.h"
00024 #include "moses/Util.h"
00025 #include "moses/Timer.h"
00026 #include "moses/Word.h"
00027 #include "Trie.h"
00028
00029 #include <istream>
00030 #include <sstream>
00031
00032 namespace Moses
00033 {
00034
00035 bool RuleTableLoaderCompact::Load(AllOptions const& opts,
00036 const std::vector<FactorType> &input,
00037 const std::vector<FactorType> &output,
00038 const std::string &inFile,
00039 size_t ,
00040 RuleTableTrie &ruleTable)
00041 {
00042 PrintUserTime("Start loading compact rule table");
00043
00044 InputFileStream inStream(inFile);
00045 LineReader reader(inStream);
00046
00047
00048 reader.ReadLine();
00049 if (reader.m_line != "1") {
00050 std::cerr << "Unexpected compact rule table format: " << reader.m_line;
00051 return false;
00052 }
00053
00054
00055 std::vector<Word> vocab;
00056 LoadVocabularySection(reader, input, vocab);
00057
00058
00059 std::vector<Phrase> sourcePhrases;
00060 std::vector<size_t> sourceLhsIds;
00061 LoadPhraseSection(reader, vocab, sourcePhrases, sourceLhsIds);
00062
00063
00064 std::vector<Phrase> targetPhrases;
00065 std::vector<size_t> targetLhsIds;
00066 LoadPhraseSection(reader, vocab, targetPhrases, targetLhsIds);
00067
00068
00069 std::vector<const AlignmentInfo *> alignmentSets;
00070 LoadAlignmentSection(reader, alignmentSets, sourcePhrases);
00071
00072
00073 if (!LoadRuleSection(reader, vocab, sourcePhrases, targetPhrases,
00074 targetLhsIds, alignmentSets,
00075 ruleTable)) {
00076 return false;
00077 }
00078
00079
00080 SortAndPrune(ruleTable);
00081
00082 return true;
00083 }
00084
00085 void RuleTableLoaderCompact::LoadVocabularySection(
00086 LineReader &reader,
00087 const std::vector<FactorType> &factorTypes,
00088 std::vector<Word> &vocabulary)
00089 {
00090
00091 reader.ReadLine();
00092 const size_t vocabSize = std::atoi(reader.m_line.c_str());
00093
00094
00095 vocabulary.resize(vocabSize);
00096 for (size_t i = 0; i < vocabSize; ++i) {
00097 reader.ReadLine();
00098 const size_t len = reader.m_line.size();
00099 bool isNonTerm = (reader.m_line[0] == '[' && reader.m_line[len-1] == ']');
00100 if (isNonTerm) {
00101 reader.m_line = reader.m_line.substr(1, len-2);
00102 }
00103 vocabulary[i].CreateFromString(Input, factorTypes, reader.m_line, isNonTerm);
00104 }
00105 }
00106
00107 void RuleTableLoaderCompact::LoadPhraseSection(
00108 LineReader &reader,
00109 const std::vector<Word> &vocab,
00110 std::vector<Phrase> &rhsPhrases,
00111 std::vector<size_t> &lhsIds)
00112 {
00113
00114 reader.ReadLine();
00115 const size_t phraseCount = std::atoi(reader.m_line.c_str());
00116
00117
00118 rhsPhrases.resize(phraseCount, Phrase(0));
00119 lhsIds.resize(phraseCount);
00120 std::vector<size_t> tokenPositions;
00121 for (size_t i = 0; i < phraseCount; ++i) {
00122 reader.ReadLine();
00123 tokenPositions.clear();
00124 FindTokens(tokenPositions, reader.m_line);
00125 const char *charLine = reader.m_line.c_str();
00126 lhsIds[i] = std::atoi(charLine+tokenPositions[0]);
00127 for (size_t j = 1; j < tokenPositions.size(); ++j) {
00128 rhsPhrases[i].AddWord(vocab[std::atoi(charLine+tokenPositions[j])]);
00129 }
00130 }
00131 }
00132
00133 void RuleTableLoaderCompact::LoadAlignmentSection(
00134 LineReader &reader, std::vector<const AlignmentInfo *> &alignmentSets, std::vector<Phrase> &sourcePhrases)
00135 {
00136
00137 reader.ReadLine();
00138 const size_t alignmentSetCount = std::atoi(reader.m_line.c_str());
00139
00140 alignmentSets.resize(alignmentSetCount * 2);
00141 AlignmentInfo::CollType alignTerm, alignNonTerm;
00142 std::vector<std::string> tokens;
00143 std::vector<size_t> points;
00144 for (size_t i = 0; i < alignmentSetCount; ++i) {
00145
00146 alignTerm.clear();
00147 alignNonTerm.clear();
00148 tokens.clear();
00149
00150 reader.ReadLine();
00151 Tokenize(tokens, reader.m_line);
00152 std::vector<std::string>::const_iterator p;
00153 for (p = tokens.begin(); p != tokens.end(); ++p) {
00154 points.clear();
00155 Tokenize<size_t>(points, *p, "-");
00156 std::pair<size_t, size_t> alignmentPair(points[0], points[1]);
00157
00158 if (sourcePhrases[i].GetWord(alignmentPair.first).IsNonTerminal()) {
00159 alignNonTerm.insert(alignmentPair);
00160 } else {
00161 alignTerm.insert(alignmentPair);
00162 }
00163
00164 }
00165 alignmentSets[i*2] = AlignmentInfoCollection::Instance().Add(alignNonTerm);
00166 alignmentSets[i*2 + 1] = AlignmentInfoCollection::Instance().Add(alignTerm);
00167 }
00168 }
00169
00170 bool RuleTableLoaderCompact::LoadRuleSection(
00171 LineReader &reader,
00172 const std::vector<Word> &vocab,
00173 const std::vector<Phrase> &sourcePhrases,
00174 const std::vector<Phrase> &targetPhrases,
00175 const std::vector<size_t> &targetLhsIds,
00176 const std::vector<const AlignmentInfo *> &alignmentSets,
00177 RuleTableTrie &ruleTable)
00178 {
00179
00180 reader.ReadLine();
00181 const size_t ruleCount = std::atoi(reader.m_line.c_str());
00182
00183
00184 const size_t numScoreComponents = ruleTable.GetNumScoreComponents();
00185 std::vector<float> scoreVector(numScoreComponents);
00186 std::vector<size_t> tokenPositions;
00187 for (size_t i = 0; i < ruleCount; ++i) {
00188 reader.ReadLine();
00189
00190 tokenPositions.clear();
00191 FindTokens(tokenPositions, reader.m_line);
00192
00193 const char *charLine = reader.m_line.c_str();
00194
00195
00196
00197 const int sourcePhraseId = std::atoi(charLine+tokenPositions[0]);
00198 const int targetPhraseId = std::atoi(charLine+tokenPositions[1]);
00199 const int alignmentSetId = std::atoi(charLine+tokenPositions[2]);
00200
00201 const Phrase &sourcePhrase = sourcePhrases[sourcePhraseId];
00202 const Phrase &targetPhrasePhrase = targetPhrases[targetPhraseId];
00203 const Word *targetLhs = new Word(vocab[targetLhsIds[targetPhraseId]]);
00204 Word sourceLHS("X");
00205 const AlignmentInfo *alignNonTerm = alignmentSets[alignmentSetId];
00206
00207
00208 for (size_t j = 0; j < numScoreComponents; ++j) {
00209 float score = std::atof(charLine+tokenPositions[3+j]);
00210 scoreVector[j] = FloorScore(TransformScore(score));
00211 }
00212 if (reader.m_line[tokenPositions[3+numScoreComponents]] != ':') {
00213 std::cerr << "Size of scoreVector != number ("
00214 << scoreVector.size() << "!=" << numScoreComponents
00215 << ") of score components on line " << reader.m_lineNum;
00216 return false;
00217 }
00218
00219
00220
00221
00222 TargetPhrase *targetPhrase = new TargetPhrase(targetPhrasePhrase, &ruleTable);
00223 targetPhrase->SetAlignNonTerm(alignNonTerm);
00224 targetPhrase->SetTargetLHS(targetLhs);
00225
00226 targetPhrase->EvaluateInIsolation(sourcePhrase, ruleTable.GetFeaturesToApply());
00227
00228
00229 TargetPhraseCollection::shared_ptr coll;
00230 coll = GetOrCreateTargetPhraseCollection(ruleTable, sourcePhrase,
00231 *targetPhrase, &sourceLHS);
00232 coll->Add(targetPhrase);
00233 }
00234
00235 return true;
00236 }
00237
00238 }