00001 #pragma once
00002
00003 #include "moses/DecodeGraph.h"
00004 #include "moses/StaticData.h"
00005 #include "moses/Syntax/BoundedPriorityContainer.h"
00006 #include "moses/Syntax/CubeQueue.h"
00007 #include "moses/Syntax/F2S/DerivationWriter.h"
00008 #include "moses/Syntax/F2S/RuleMatcherCallback.h"
00009 #include "moses/Syntax/PHyperedge.h"
00010 #include "moses/Syntax/RuleTable.h"
00011 #include "moses/Syntax/RuleTableFF.h"
00012 #include "moses/Syntax/SHyperedgeBundle.h"
00013 #include "moses/Syntax/SVertex.h"
00014 #include "moses/Syntax/SVertexRecombinationEqualityPred.h"
00015 #include "moses/Syntax/SVertexRecombinationHasher.h"
00016 #include "moses/Syntax/SymbolEqualityPred.h"
00017 #include "moses/Syntax/SymbolHasher.h"
00018
00019 #include "GlueRuleSynthesizer.h"
00020 #include "InputTreeBuilder.h"
00021 #include "RuleTrie.h"
00022
00023 namespace Moses
00024 {
00025 namespace Syntax
00026 {
00027 namespace T2S
00028 {
00029
00030 template<typename RuleMatcher>
00031 Manager<RuleMatcher>::Manager(ttasksptr const& ttask)
00032 : Syntax::Manager(ttask)
00033 {
00034 if (const TreeInput *p = dynamic_cast<const TreeInput*>(&m_source)) {
00035
00036 InputTreeBuilder builder(options()->output.factor_order);
00037 builder.Build(*p, "Q", m_inputTree);
00038 } else {
00039 UTIL_THROW2("ERROR: T2S::Manager requires input to be a tree");
00040 }
00041 }
00042
00043 template<typename RuleMatcher>
00044 void Manager<RuleMatcher>::InitializeRuleMatchers()
00045 {
00046 const std::vector<RuleTableFF*> &ffs = RuleTableFF::Instances();
00047 for (std::size_t i = 0; i < ffs.size(); ++i) {
00048 RuleTableFF *ff = ffs[i];
00049
00050
00051
00052
00053 const RuleTable *table = ff->GetTable();
00054 assert(table);
00055 RuleTable *nonConstTable = const_cast<RuleTable*>(table);
00056 RuleTrie *trie = dynamic_cast<RuleTrie*>(nonConstTable);
00057 assert(trie);
00058 boost::shared_ptr<RuleMatcher> p(new RuleMatcher(m_inputTree, *trie));
00059 m_ruleMatchers.push_back(p);
00060 }
00061
00062
00063
00064
00065 m_glueRuleTrie.reset(new RuleTrie(ffs[0]));
00066 boost::shared_ptr<RuleMatcher> p(new RuleMatcher(m_inputTree, *m_glueRuleTrie));
00067 m_ruleMatchers.push_back(p);
00068 m_glueRuleMatcher = p.get();
00069 }
00070
00071 template<typename RuleMatcher>
00072 void Manager<RuleMatcher>::InitializeStacks()
00073 {
00074
00075 assert(!m_inputTree.nodes.empty());
00076
00077 for (std::vector<InputTree::Node>::const_iterator p =
00078 m_inputTree.nodes.begin(); p != m_inputTree.nodes.end(); ++p) {
00079 const InputTree::Node &node = *p;
00080
00081
00082 SVertexStack &stack = m_stackMap[&(node.pvertex)];
00083
00084
00085 if (node.children.empty()) {
00086 boost::shared_ptr<SVertex> v(new SVertex());
00087 v->best = 0;
00088 v->pvertex = &(node.pvertex);
00089 stack.push_back(v);
00090 }
00091 }
00092 }
00093
00094 template<typename RuleMatcher>
00095 void Manager<RuleMatcher>::Decode()
00096 {
00097
00098
00099
00100 const std::size_t popLimit = this->options()->cube.pop_limit;
00101 const std::size_t ruleLimit = this->options()->syntax.rule_limit;
00102 const std::size_t stackLimit = this->options()->search.stack_size;
00103
00104
00105 InitializeStacks();
00106
00107
00108 InitializeRuleMatchers();
00109
00110
00111 F2S::RuleMatcherCallback callback(m_stackMap, ruleLimit);
00112
00113
00114 Word dflt_nonterm = options()->syntax.output_default_non_terminal;
00115 GlueRuleSynthesizer glueRuleSynthesizer(*m_glueRuleTrie, dflt_nonterm);
00116
00117
00118 for (std::vector<InputTree::Node>::const_iterator p =
00119 m_inputTree.nodes.begin(); p != m_inputTree.nodes.end(); ++p) {
00120
00121 const InputTree::Node &node = *p;
00122
00123
00124 if (node.children.empty()) {
00125 continue;
00126 }
00127
00128
00129
00130
00131
00132 callback.ClearContainer();
00133 for (typename std::vector<boost::shared_ptr<RuleMatcher> >::iterator
00134 q = m_ruleMatchers.begin(); q != m_ruleMatchers.end(); ++q) {
00135 (*q)->EnumerateHyperedges(node, callback);
00136 }
00137
00138
00139 const BoundedPriorityContainer<SHyperedgeBundle> &bundles =
00140 callback.GetContainer();
00141
00142
00143
00144 if (bundles.Size() == 0) {
00145 glueRuleSynthesizer.SynthesizeRule(node);
00146 m_glueRuleMatcher->EnumerateHyperedges(node, callback);
00147 assert(bundles.Size() == 1);
00148 }
00149
00150
00151
00152 CubeQueue cubeQueue(bundles.Begin(), bundles.End());
00153 std::size_t count = 0;
00154 std::vector<SHyperedge*> buffer;
00155 while (count < popLimit && !cubeQueue.IsEmpty()) {
00156 SHyperedge *hyperedge = cubeQueue.Pop();
00157
00158
00159 hyperedge->head->pvertex = &(node.pvertex);
00160
00161 buffer.push_back(hyperedge);
00162 ++count;
00163 }
00164
00165
00166 SVertexStack &stack = m_stackMap[&(node.pvertex)];
00167 RecombineAndSort(buffer, stack);
00168
00169
00170 if (stackLimit > 0 && stack.size() > stackLimit) {
00171 stack.resize(stackLimit);
00172 }
00173 }
00174 }
00175
00176 template<typename RuleMatcher>
00177 const SHyperedge *Manager<RuleMatcher>::GetBestSHyperedge() const
00178 {
00179 const InputTree::Node &rootNode = m_inputTree.nodes.back();
00180 F2S::PVertexToStackMap::const_iterator p = m_stackMap.find(&rootNode.pvertex);
00181 assert(p != m_stackMap.end());
00182 const SVertexStack &stack = p->second;
00183 assert(!stack.empty());
00184 return stack[0]->best;
00185 }
00186
00187 template<typename RuleMatcher>
00188 void Manager<RuleMatcher>::ExtractKBest(
00189 std::size_t k,
00190 std::vector<boost::shared_ptr<KBestExtractor::Derivation> > &kBestList,
00191 bool onlyDistinct) const
00192 {
00193 kBestList.clear();
00194 if (k == 0 || m_source.GetSize() == 0) {
00195 return;
00196 }
00197
00198
00199 const InputTree::Node &rootNode = m_inputTree.nodes.back();
00200 F2S::PVertexToStackMap::const_iterator p = m_stackMap.find(&rootNode.pvertex);
00201 assert(p != m_stackMap.end());
00202 const SVertexStack &stack = p->second;
00203 assert(!stack.empty());
00204
00205 KBestExtractor extractor;
00206
00207 if (!onlyDistinct) {
00208
00209 extractor.Extract(stack, k, kBestList);
00210 return;
00211 }
00212
00213
00214
00215
00216
00217
00218
00219 const std::size_t nBestFactor = this->options()->nbest.factor;
00220 std::size_t numDerivations = (nBestFactor == 0) ? k*1000 : k*nBestFactor;
00221
00222
00223 KBestExtractor::KBestVec bigList;
00224 bigList.reserve(numDerivations);
00225 extractor.Extract(stack, numDerivations, bigList);
00226
00227
00228 std::set<Phrase> distinct;
00229 for (KBestExtractor::KBestVec::const_iterator p = bigList.begin();
00230 kBestList.size() < k && p != bigList.end(); ++p) {
00231 boost::shared_ptr<KBestExtractor::Derivation> derivation = *p;
00232 Phrase translation = KBestExtractor::GetOutputPhrase(*derivation);
00233 if (distinct.insert(translation).second) {
00234 kBestList.push_back(derivation);
00235 }
00236 }
00237 }
00238
00239
00240
00241 template<typename RuleMatcher>
00242 void Manager<RuleMatcher>::RecombineAndSort(
00243 const std::vector<SHyperedge*> &buffer, SVertexStack &stack)
00244 {
00245
00246
00247
00248
00249
00250 typedef boost::unordered_map<SVertex *, SVertex *,
00251 SVertexRecombinationHasher,
00252 SVertexRecombinationEqualityPred> Map;
00253 Map map;
00254 for (std::vector<SHyperedge*>::const_iterator p = buffer.begin();
00255 p != buffer.end(); ++p) {
00256 SHyperedge *h = *p;
00257 SVertex *v = h->head;
00258 assert(v->best == h);
00259 assert(v->recombined.empty());
00260 std::pair<Map::iterator, bool> result = map.insert(Map::value_type(v, v));
00261 if (result.second) {
00262 continue;
00263 }
00264
00265
00266
00267 SVertex *storedVertex = result.first->second;
00268 if (h->label.futureScore > storedVertex->best->label.futureScore) {
00269
00270 storedVertex->recombined.push_back(storedVertex->best);
00271 storedVertex->best = h;
00272 } else {
00273 storedVertex->recombined.push_back(h);
00274 }
00275 h->head->best = 0;
00276 delete h->head;
00277 h->head = storedVertex;
00278 }
00279
00280
00281 stack.clear();
00282 stack.reserve(map.size());
00283 for (Map::const_iterator p = map.begin(); p != map.end(); ++p) {
00284 stack.push_back(boost::shared_ptr<SVertex>(p->first));
00285 }
00286
00287
00288 std::sort(stack.begin(), stack.end(), SVertexStackContentOrderer());
00289 }
00290
00291 template<typename RuleMatcher>
00292 void Manager<RuleMatcher>::OutputDetailedTranslationReport(
00293 OutputCollector *collector) const
00294 {
00295 const SHyperedge *best = GetBestSHyperedge();
00296 if (best == NULL || collector == NULL) {
00297 return;
00298 }
00299 long translationId = m_source.GetTranslationId();
00300 std::ostringstream out;
00301 F2S::DerivationWriter::Write(*best, translationId, out);
00302 collector->Write(translationId, out.str());
00303 }
00304
00305 }
00306 }
00307 }