00001
00002 #pragma once
00003
00004 #include "moses/DecodeGraph.h"
00005 #include "moses/ForestInput.h"
00006 #include "moses/StaticData.h"
00007 #include "moses/Syntax/BoundedPriorityContainer.h"
00008 #include "moses/Syntax/CubeQueue.h"
00009 #include "moses/Syntax/PHyperedge.h"
00010 #include "moses/Syntax/RuleTable.h"
00011 #include "moses/Syntax/RuleTableFF.h"
00012 #include "moses/Syntax/SHyperedgeBundle.h"
00013 #include "moses/Syntax/SVertex.h"
00014 #include "moses/Syntax/SVertexRecombinationEqualityPred.h"
00015 #include "moses/Syntax/SVertexRecombinationHasher.h"
00016 #include "moses/Syntax/SymbolEqualityPred.h"
00017 #include "moses/Syntax/SymbolHasher.h"
00018 #include "moses/Syntax/T2S/InputTree.h"
00019 #include "moses/Syntax/T2S/InputTreeBuilder.h"
00020 #include "moses/Syntax/T2S/InputTreeToForest.h"
00021 #include "moses/TreeInput.h"
00022
00023 #include "DerivationWriter.h"
00024 #include "GlueRuleSynthesizer.h"
00025 #include "HyperTree.h"
00026 #include "RuleMatcherCallback.h"
00027 #include "TopologicalSorter.h"
00028
00029 namespace Moses
00030 {
00031 namespace Syntax
00032 {
00033 namespace F2S
00034 {
00035
00036 template<typename RuleMatcher>
00037 Manager<RuleMatcher>::Manager(ttasksptr const& ttask)
00038 : Syntax::Manager(ttask)
00039 {
00040 if (const ForestInput *p = dynamic_cast<const ForestInput*>(&m_source)) {
00041 m_forest = p->GetForest();
00042 m_rootVertex = p->GetRootVertex();
00043 m_sentenceLength = p->GetSize();
00044 } else if (const TreeInput *p = dynamic_cast<const TreeInput*>(&m_source)) {
00045 T2S::InputTreeBuilder builder(options()->output.factor_order);
00046 T2S::InputTree tmpTree;
00047 builder.Build(*p, "Q", tmpTree);
00048 boost::shared_ptr<Forest> forest = boost::make_shared<Forest>();
00049 m_rootVertex = T2S::InputTreeToForest(tmpTree, *forest);
00050 m_forest = forest;
00051 m_sentenceLength = p->GetSize();
00052 } else {
00053 UTIL_THROW2("ERROR: F2S::Manager requires input to be a tree or forest");
00054 }
00055 }
00056
00057 template<typename RuleMatcher>
00058 void Manager<RuleMatcher>::Decode()
00059 {
00060
00061 const std::size_t popLimit = options()->cube.pop_limit;
00062 const std::size_t ruleLimit = options()->syntax.rule_limit;
00063 const std::size_t stackLimit = options()->search.stack_size;
00064
00065
00066 InitializeStacks();
00067
00068
00069 InitializeRuleMatchers();
00070
00071
00072 RuleMatcherCallback callback(m_stackMap, ruleLimit);
00073
00074
00075 GlueRuleSynthesizer glueRuleSynthesizer(*options(), *m_glueRuleTrie);
00076
00077
00078 std::vector<const Forest::Vertex *> sortedVertices;
00079 TopologicalSorter sorter;
00080 sorter.Sort(*m_forest, sortedVertices);
00081
00082
00083 for (std::vector<const Forest::Vertex *>::const_iterator
00084 p = sortedVertices.begin(); p != sortedVertices.end(); ++p) {
00085 const Forest::Vertex &vertex = **p;
00086
00087
00088 if (vertex.incoming.empty()) {
00089 if (vertex.pvertex.span.GetStartPos() > 0 &&
00090 vertex.pvertex.span.GetEndPos() < m_sentenceLength-1 &&
00091 IsUnknownSourceWord(vertex.pvertex.symbol)) {
00092 m_oovs.insert(vertex.pvertex.symbol);
00093 }
00094 continue;
00095 }
00096
00097
00098
00099
00100
00101 callback.ClearContainer();
00102 for (typename std::vector<boost::shared_ptr<RuleMatcher> >::iterator
00103 q = m_mainRuleMatchers.begin(); q != m_mainRuleMatchers.end(); ++q) {
00104 (*q)->EnumerateHyperedges(vertex, callback);
00105 }
00106
00107
00108 const BoundedPriorityContainer<SHyperedgeBundle> &bundles =
00109 callback.GetContainer();
00110
00111
00112
00113 if (bundles.Size() == 0) {
00114 for (std::vector<Forest::Hyperedge *>::const_iterator p =
00115 vertex.incoming.begin(); p != vertex.incoming.end(); ++p) {
00116 glueRuleSynthesizer.SynthesizeRule(**p);
00117 }
00118 m_glueRuleMatcher->EnumerateHyperedges(vertex, callback);
00119
00120
00121 }
00122
00123
00124
00125 CubeQueue cubeQueue(bundles.Begin(), bundles.End());
00126 std::size_t count = 0;
00127 std::vector<SHyperedge*> buffer;
00128 while (count < popLimit && !cubeQueue.IsEmpty()) {
00129 SHyperedge *hyperedge = cubeQueue.Pop();
00130
00131
00132 hyperedge->head->pvertex = &(vertex.pvertex);
00133
00134 buffer.push_back(hyperedge);
00135 ++count;
00136 }
00137
00138
00139 SVertexStack &stack = m_stackMap[&(vertex.pvertex)];
00140 RecombineAndSort(buffer, stack);
00141
00142
00143 if (stackLimit > 0 && stack.size() > stackLimit) {
00144 stack.resize(stackLimit);
00145 }
00146 }
00147 }
00148
00149 template<typename RuleMatcher>
00150 void Manager<RuleMatcher>::InitializeRuleMatchers()
00151 {
00152 const std::vector<RuleTableFF*> &ffs = RuleTableFF::Instances();
00153 for (std::size_t i = 0; i < ffs.size(); ++i) {
00154 RuleTableFF *ff = ffs[i];
00155
00156
00157
00158
00159 const RuleTable *table = ff->GetTable();
00160 assert(table);
00161 RuleTable *nonConstTable = const_cast<RuleTable*>(table);
00162 HyperTree *trie = dynamic_cast<HyperTree*>(nonConstTable);
00163 assert(trie);
00164 boost::shared_ptr<RuleMatcher> p(new RuleMatcher(*trie));
00165 m_mainRuleMatchers.push_back(p);
00166 }
00167
00168
00169
00170
00171 m_glueRuleTrie.reset(new HyperTree(ffs[0]));
00172 m_glueRuleMatcher = boost::shared_ptr<RuleMatcher>(
00173 new RuleMatcher(*m_glueRuleTrie));
00174 }
00175
00176 template<typename RuleMatcher>
00177 void Manager<RuleMatcher>::InitializeStacks()
00178 {
00179
00180 assert(!m_forest->vertices.empty());
00181
00182 for (std::vector<Forest::Vertex *>::const_iterator
00183 p = m_forest->vertices.begin(); p != m_forest->vertices.end(); ++p) {
00184 const Forest::Vertex &vertex = **p;
00185
00186
00187 SVertexStack &stack = m_stackMap[&(vertex.pvertex)];
00188
00189
00190 if (vertex.incoming.empty()) {
00191 boost::shared_ptr<SVertex> v(new SVertex());
00192 v->best = 0;
00193 v->pvertex = &(vertex.pvertex);
00194 stack.push_back(v);
00195 }
00196 }
00197 }
00198
00199 template<typename RuleMatcher>
00200 bool Manager<RuleMatcher>::IsUnknownSourceWord(const Word &w) const
00201 {
00202 const std::size_t factorId = w[0]->GetId();
00203 const std::vector<RuleTableFF*> &ffs = RuleTableFF::Instances();
00204 for (std::size_t i = 0; i < ffs.size(); ++i) {
00205 RuleTableFF *ff = ffs[i];
00206 const boost::unordered_set<std::size_t> &sourceTerms =
00207 ff->GetSourceTerminalSet();
00208 if (sourceTerms.find(factorId) != sourceTerms.end()) {
00209 return false;
00210 }
00211 }
00212 return true;
00213 }
00214
00215 template<typename RuleMatcher>
00216 const SHyperedge *Manager<RuleMatcher>::GetBestSHyperedge() const
00217 {
00218 PVertexToStackMap::const_iterator p = m_stackMap.find(&m_rootVertex->pvertex);
00219 assert(p != m_stackMap.end());
00220 const SVertexStack &stack = p->second;
00221 assert(!stack.empty());
00222 return stack[0]->best;
00223 }
00224
00225 template<typename RuleMatcher>
00226 void Manager<RuleMatcher>::ExtractKBest(
00227 std::size_t k,
00228 std::vector<boost::shared_ptr<KBestExtractor::Derivation> > &kBestList,
00229 bool onlyDistinct) const
00230 {
00231 kBestList.clear();
00232 if (k == 0 || m_source.GetSize() == 0) {
00233 return;
00234 }
00235
00236
00237 PVertexToStackMap::const_iterator p = m_stackMap.find(&m_rootVertex->pvertex);
00238 assert(p != m_stackMap.end());
00239 const SVertexStack &stack = p->second;
00240 assert(!stack.empty());
00241
00242 KBestExtractor extractor;
00243
00244 if (!onlyDistinct) {
00245
00246 extractor.Extract(stack, k, kBestList);
00247 return;
00248 }
00249
00250
00251
00252
00253
00254
00255 const StaticData &staticData = StaticData::Instance();
00256 const std::size_t nBestFactor = staticData.options()->nbest.factor;
00257 std::size_t numDerivations = (nBestFactor == 0) ? k*1000 : k*nBestFactor;
00258
00259
00260 KBestExtractor::KBestVec bigList;
00261 bigList.reserve(numDerivations);
00262 extractor.Extract(stack, numDerivations, bigList);
00263
00264
00265 std::set<Phrase> distinct;
00266 for (KBestExtractor::KBestVec::const_iterator p = bigList.begin();
00267 kBestList.size() < k && p != bigList.end(); ++p) {
00268 boost::shared_ptr<KBestExtractor::Derivation> derivation = *p;
00269 Phrase translation = KBestExtractor::GetOutputPhrase(*derivation);
00270 if (distinct.insert(translation).second) {
00271 kBestList.push_back(derivation);
00272 }
00273 }
00274 }
00275
00276
00277
00278 template<typename RuleMatcher>
00279 void Manager<RuleMatcher>::RecombineAndSort(
00280 const std::vector<SHyperedge*> &buffer, SVertexStack &stack)
00281 {
00282
00283
00284
00285
00286
00287 typedef boost::unordered_map<SVertex *, SVertex *,
00288 SVertexRecombinationHasher,
00289 SVertexRecombinationEqualityPred> Map;
00290 Map map;
00291 for (std::vector<SHyperedge*>::const_iterator p = buffer.begin();
00292 p != buffer.end(); ++p) {
00293 SHyperedge *h = *p;
00294 SVertex *v = h->head;
00295 assert(v->best == h);
00296 assert(v->recombined.empty());
00297 std::pair<Map::iterator, bool> result = map.insert(Map::value_type(v, v));
00298 if (result.second) {
00299 continue;
00300 }
00301
00302
00303
00304 SVertex *storedVertex = result.first->second;
00305 if (h->label.futureScore > storedVertex->best->label.futureScore) {
00306
00307 storedVertex->recombined.push_back(storedVertex->best);
00308 storedVertex->best = h;
00309 } else {
00310 storedVertex->recombined.push_back(h);
00311 }
00312 h->head->best = 0;
00313 delete h->head;
00314 h->head = storedVertex;
00315 }
00316
00317
00318 stack.clear();
00319 stack.reserve(map.size());
00320 for (Map::const_iterator p = map.begin(); p != map.end(); ++p) {
00321 stack.push_back(boost::shared_ptr<SVertex>(p->first));
00322 }
00323
00324
00325 std::sort(stack.begin(), stack.end(), SVertexStackContentOrderer());
00326 }
00327
00328 template<typename RuleMatcher>
00329 void Manager<RuleMatcher>::OutputDetailedTranslationReport(
00330 OutputCollector *collector) const
00331 {
00332 const SHyperedge *best = GetBestSHyperedge();
00333 if (best == NULL || collector == NULL) {
00334 return;
00335 }
00336 long translationId = m_source.GetTranslationId();
00337 std::ostringstream out;
00338 DerivationWriter::Write(*best, translationId, out);
00339 collector->Write(translationId, out.str());
00340 }
00341
00342 }
00343 }
00344 }