00001 #pragma once
00002
00003 #include "moses/Syntax/S2T/PChart.h"
00004
00005 namespace Moses
00006 {
00007 namespace Syntax
00008 {
00009 namespace S2T
00010 {
00011
00012 template<typename Callback>
00013 RecursiveCYKPlusParser<Callback>::RecursiveCYKPlusParser(
00014 PChart &chart,
00015 const RuleTrie &trie,
00016 std::size_t maxChartSpan)
00017 : Parser<Callback>(chart)
00018 , m_ruleTable(trie)
00019 , m_maxChartSpan(maxChartSpan)
00020 , m_callback(NULL)
00021 {
00022 m_hyperedge.head = 0;
00023 }
00024
00025 template<typename Callback>
00026 void RecursiveCYKPlusParser<Callback>::EnumerateHyperedges(
00027 const Range &range,
00028 Callback &callback)
00029 {
00030 const std::size_t start = range.GetStartPos();
00031 const std::size_t end = range.GetEndPos();
00032 m_callback = &callback;
00033 const RuleTrie::Node &rootNode = m_ruleTable.GetRootNode();
00034 m_maxEnd = std::min(Base::m_chart.GetWidth()-1, start+m_maxChartSpan-1);
00035 m_hyperedge.tail.clear();
00036
00037
00038
00039 GetTerminalExtension(rootNode, start, end);
00040
00041
00042
00043 if (end > start) {
00044 GetNonTerminalExtensions(rootNode, start, end-1, end-1);
00045 }
00046 }
00047
00048
00049
00050 template<typename Callback>
00051 void RecursiveCYKPlusParser<Callback>::GetNonTerminalExtensions(
00052 const RuleTrie::Node &node,
00053 std::size_t start,
00054 std::size_t minEnd,
00055 std::size_t maxEnd)
00056 {
00057
00058 const RuleTrie::Node::SymbolMap &nonTermMap = node.GetNonTerminalMap();
00059
00060
00061 const PChart::CompressedMatrix &matrix =
00062 Base::m_chart.GetCompressedMatrix(start);
00063
00064
00065 RuleTrie::Node::SymbolMap::const_iterator p;
00066 RuleTrie::Node::SymbolMap::const_iterator p_end = nonTermMap.end();
00067 for (p = nonTermMap.begin(); p != p_end; ++p) {
00068 const Word &nonTerm = p->first;
00069 const std::vector<PChart::CompressedItem> &items =
00070 matrix[nonTerm[0]->GetId()];
00071 for (std::vector<PChart::CompressedItem>::const_iterator q = items.begin();
00072 q != items.end(); ++q) {
00073 if (q->end >= minEnd && q->end <= maxEnd) {
00074 const RuleTrie::Node &child = p->second;
00075 AddAndExtend(child, q->end, *(q->vertex));
00076 }
00077 }
00078 }
00079 }
00080
00081
00082
00083 template<typename Callback>
00084 void RecursiveCYKPlusParser<Callback>::GetTerminalExtension(
00085 const RuleTrie::Node &node,
00086 std::size_t start,
00087 std::size_t end)
00088 {
00089
00090 const PChart::Cell::TMap &vertexMap =
00091 Base::m_chart.GetCell(start, end).terminalVertices;
00092 if (vertexMap.empty()) {
00093 return;
00094 }
00095
00096 const RuleTrie::Node::SymbolMap &terminals = node.GetTerminalMap();
00097
00098 for (PChart::Cell::TMap::const_iterator p = vertexMap.begin();
00099 p != vertexMap.end(); ++p) {
00100 const Word &terminal = p->first;
00101 const PVertex &vertex = p->second;
00102
00103
00104 if (terminals.size() < 5) {
00105 for (RuleTrie::Node::SymbolMap::const_iterator iter = terminals.begin();
00106 iter != terminals.end(); ++iter) {
00107 const Word &word = iter->first;
00108 if (word == terminal) {
00109 const RuleTrie::Node *child = & iter->second;
00110 AddAndExtend(*child, end, vertex);
00111 break;
00112 }
00113 }
00114 } else {
00115 const RuleTrie::Node *child = node.GetChild(terminal);
00116 if (child != NULL) {
00117 AddAndExtend(*child, end, vertex);
00118 }
00119 }
00120 }
00121 }
00122
00123
00124
00125 template<typename Callback>
00126 void RecursiveCYKPlusParser<Callback>::AddAndExtend(
00127 const RuleTrie::Node &node,
00128 std::size_t end,
00129 const PVertex &vertex)
00130 {
00131
00132 m_hyperedge.tail.push_back(const_cast<PVertex *>(&vertex));
00133
00134
00135 TargetPhraseCollection::shared_ptr tpc = node.GetTargetPhraseCollection();
00136 if (!tpc->IsEmpty() && !IsNonLexicalUnary(m_hyperedge)) {
00137 m_hyperedge.label.translations = tpc;
00138 (*m_callback)(m_hyperedge, end);
00139 }
00140
00141
00142
00143 if (end < m_maxEnd) {
00144 if (!node.GetTerminalMap().empty()) {
00145 for (std::size_t newEndPos = end+1; newEndPos <= m_maxEnd; newEndPos++) {
00146 GetTerminalExtension(node, end+1, newEndPos);
00147 }
00148 }
00149 if (!node.GetNonTerminalMap().empty()) {
00150 GetNonTerminalExtensions(node, end+1, end+1, m_maxEnd);
00151 }
00152 }
00153
00154 m_hyperedge.tail.pop_back();
00155 }
00156
00157 template<typename Callback>
00158 bool RecursiveCYKPlusParser<Callback>::IsNonLexicalUnary(
00159 const PHyperedge &hyperedge) const
00160 {
00161 return hyperedge.tail.size() == 1 &&
00162 hyperedge.tail[0]->symbol.IsNonTerminal();
00163 }
00164
00165 }
00166 }
00167 }