00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020 #include <iostream>
00021 #include "ChartRuleLookupManagerMemory.h"
00022
00023 #include "moses/ChartParser.h"
00024 #include "moses/InputType.h"
00025 #include "moses/Terminal.h"
00026 #include "moses/ChartParserCallback.h"
00027 #include "moses/StaticData.h"
00028 #include "moses/NonTerminal.h"
00029 #include "moses/ChartCellCollection.h"
00030 #include "moses/FactorCollection.h"
00031 #include "moses/TranslationModel/PhraseDictionaryMemory.h"
00032
00033 using namespace std;
00034
00035 namespace Moses
00036 {
00037
00038 ChartRuleLookupManagerMemory::ChartRuleLookupManagerMemory(
00039 const ChartParser &parser,
00040 const ChartCellCollectionBase &cellColl,
00041 const PhraseDictionaryMemory &ruleTable)
00042 : ChartRuleLookupManagerCYKPlus(parser, cellColl)
00043 , m_ruleTable(ruleTable)
00044 , m_softMatchingMap(StaticData::Instance().GetSoftMatches())
00045 {
00046
00047 size_t sourceSize = parser.GetSize();
00048 size_t ruleLimit = parser.options()->syntax.rule_limit;
00049 m_completedRules.resize(sourceSize, CompletedRuleCollection(ruleLimit));
00050
00051 m_isSoftMatching = !m_softMatchingMap.empty();
00052 }
00053
00054 void ChartRuleLookupManagerMemory::GetChartRuleCollection(
00055 const InputPath &inputPath,
00056 size_t lastPos,
00057 ChartParserCallback &outColl)
00058 {
00059 const Range &range = inputPath.GetWordsRange();
00060 size_t startPos = range.GetStartPos();
00061 size_t absEndPos = range.GetEndPos();
00062
00063 m_lastPos = lastPos;
00064 m_stackVec.clear();
00065 m_stackScores.clear();
00066 m_outColl = &outColl;
00067 m_unaryPos = absEndPos-1;
00068
00069
00070 UpdateCompressedMatrix(startPos, absEndPos, lastPos);
00071
00072 const PhraseDictionaryNodeMemory &rootNode = m_ruleTable.GetRootNode();
00073
00074
00075 if (startPos == absEndPos) {
00076 GetTerminalExtension(&rootNode, startPos);
00077 }
00078
00079 else if (absEndPos > startPos) {
00080 GetNonTerminalExtension(&rootNode, startPos);
00081 }
00082
00083
00084 CompletedRuleCollection & rules = m_completedRules[absEndPos];
00085 for (vector<CompletedRule*>::const_iterator iter = rules.begin(); iter != rules.end(); ++iter) {
00086 outColl.Add((*iter)->GetTPC(), (*iter)->GetStackVector(), range);
00087 }
00088
00089 rules.Clear();
00090
00091 }
00092
00093
00094 void ChartRuleLookupManagerMemory::UpdateCompressedMatrix(size_t startPos,
00095 size_t origEndPos,
00096 size_t lastPos)
00097 {
00098
00099 std::vector<size_t> endPosVec;
00100 size_t numNonTerms = FactorCollection::Instance().GetNumNonTerminals();
00101 m_compressedMatrixVec.resize(lastPos+1);
00102
00103
00104 if (startPos < origEndPos) {
00105 endPosVec.push_back(origEndPos-1);
00106 }
00107
00108
00109 else if (startPos == origEndPos) {
00110 startPos++;
00111 for (size_t endPos = startPos; endPos <= lastPos; endPos++) {
00112 endPosVec.push_back(endPos);
00113 }
00114
00115 for (size_t pos = startPos+1; pos <= lastPos; pos++) {
00116 CompressedMatrix & cellMatrix = m_compressedMatrixVec[pos];
00117 cellMatrix.resize(numNonTerms);
00118 for (size_t i = 0; i < numNonTerms; i++) {
00119 if (!cellMatrix[i].empty() && cellMatrix[i].back().endPos > lastPos) {
00120 cellMatrix[i].pop_back();
00121 }
00122 }
00123 }
00124 }
00125
00126 if (startPos > lastPos) {
00127 return;
00128 }
00129
00130
00131 CompressedMatrix & cellMatrix = m_compressedMatrixVec[startPos];
00132 cellMatrix.clear();
00133 cellMatrix.resize(numNonTerms);
00134 for (std::vector<size_t>::iterator p = endPosVec.begin(); p != endPosVec.end(); ++p) {
00135
00136 size_t endPos = *p;
00137
00138 const ChartCellLabelSet &targetNonTerms = GetTargetLabelSet(startPos, endPos);
00139
00140 if (targetNonTerms.GetSize() == 0) {
00141 continue;
00142 }
00143
00144 #if !defined(UNLABELLED_SOURCE)
00145
00146 const InputPath &inputPath = GetParser().GetInputPath(startPos, endPos);
00147
00148
00149 if (inputPath.GetNonTerminalSet().size() == 0) {
00150 continue;
00151 }
00152 #endif
00153
00154 for (size_t i = 0; i < numNonTerms; i++) {
00155 const ChartCellLabel *cellLabel = targetNonTerms.Find(i);
00156 if (cellLabel != NULL) {
00157 float score = cellLabel->GetBestScore(m_outColl);
00158 cellMatrix[i].push_back(ChartCellCache(endPos, cellLabel, score));
00159 }
00160 }
00161 }
00162 }
00163
00164
00165 void ChartRuleLookupManagerMemory::AddAndExtend(
00166 const PhraseDictionaryNodeMemory *node,
00167 size_t endPos)
00168 {
00169
00170 TargetPhraseCollection::shared_ptr tpc = node->GetTargetPhraseCollection();
00171
00172 if (!tpc->IsEmpty() && (m_stackVec.empty() || endPos != m_unaryPos)) {
00173 m_completedRules[endPos].Add(*tpc, m_stackVec, m_stackScores, *m_outColl);
00174 }
00175
00176
00177 if (endPos < m_lastPos) {
00178 if (!node->GetTerminalMap().empty()) {
00179 GetTerminalExtension(node, endPos+1);
00180 }
00181 if (!node->GetNonTerminalMap().empty()) {
00182 GetNonTerminalExtension(node, endPos+1);
00183 }
00184 }
00185 }
00186
00187
00188
00189
00190 void ChartRuleLookupManagerMemory::GetTerminalExtension(
00191 const PhraseDictionaryNodeMemory *node,
00192 size_t pos)
00193 {
00194
00195 const Word &sourceWord = GetSourceAt(pos).GetLabel();
00196 const PhraseDictionaryNodeMemory::TerminalMap & terminals = node->GetTerminalMap();
00197
00198
00199 if (terminals.size() < 5) {
00200 for (PhraseDictionaryNodeMemory::TerminalMap::const_iterator iter = terminals.begin(); iter != terminals.end(); ++iter) {
00201 const Word & word = iter->first;
00202 if (TerminalEqualityPred()(word, sourceWord)) {
00203 const PhraseDictionaryNodeMemory *child = & iter->second;
00204 AddAndExtend(child, pos);
00205 break;
00206 }
00207 }
00208 }
00209
00210 else {
00211 const PhraseDictionaryNodeMemory *child = node->GetChild(sourceWord);
00212 if (child != NULL) {
00213 AddAndExtend(child, pos);
00214 }
00215 }
00216 }
00217
00218
00219
00220 void ChartRuleLookupManagerMemory::GetNonTerminalExtension(
00221 const PhraseDictionaryNodeMemory *node,
00222 size_t startPos)
00223 {
00224
00225 const CompressedMatrix &compressedMatrix = m_compressedMatrixVec[startPos];
00226
00227
00228 const PhraseDictionaryNodeMemory::NonTerminalMap & nonTermMap = node->GetNonTerminalMap();
00229
00230
00231 m_stackVec.push_back(NULL);
00232 m_stackScores.push_back(0);
00233
00234
00235 PhraseDictionaryNodeMemory::NonTerminalMap::const_iterator p;
00236 PhraseDictionaryNodeMemory::NonTerminalMap::const_iterator end = nonTermMap.end();
00237 for (p = nonTermMap.begin(); p != end; ++p) {
00238
00239 #if defined(UNLABELLED_SOURCE)
00240 const Word &targetNonTerm = p->first;
00241 #else
00242 const Word &targetNonTerm = p->first.second;
00243 #endif
00244 const PhraseDictionaryNodeMemory *child = &p->second;
00245
00246 if (m_isSoftMatching && !m_softMatchingMap[targetNonTerm[0]->GetId()].empty()) {
00247 const std::vector<Word>& softMatches = m_softMatchingMap[targetNonTerm[0]->GetId()];
00248 for (std::vector<Word>::const_iterator softMatch = softMatches.begin(); softMatch != softMatches.end(); ++softMatch) {
00249 const CompressedColumn &matches = compressedMatrix[(*softMatch)[0]->GetId()];
00250 for (CompressedColumn::const_iterator match = matches.begin(); match != matches.end(); ++match) {
00251 m_stackVec.back() = match->cellLabel;
00252 m_stackScores.back() = match->score;
00253 AddAndExtend(child, match->endPos);
00254 }
00255 }
00256 }
00257
00258 const CompressedColumn &matches = compressedMatrix[targetNonTerm[0]->GetId()];
00259 for (CompressedColumn::const_iterator match = matches.begin(); match != matches.end(); ++match) {
00260 m_stackVec.back() = match->cellLabel;
00261 m_stackScores.back() = match->score;
00262 AddAndExtend(child, match->endPos);
00263 }
00264 }
00265
00266 m_stackVec.pop_back();
00267 m_stackScores.pop_back();
00268 }
00269
00270 }