00001 #pragma once
00002
00003 namespace Moses
00004 {
00005 namespace Syntax
00006 {
00007 namespace F2S
00008 {
00009
00010 template<typename Callback>
00011 RuleMatcherHyperTree<Callback>::RuleMatcherHyperTree(const HyperTree &ruleTrie)
00012 : m_ruleTrie(ruleTrie)
00013 {
00014 }
00015
00016 template<typename Callback>
00017 void RuleMatcherHyperTree<Callback>::EnumerateHyperedges(
00018 const Forest::Vertex &v, Callback &callback)
00019 {
00020 const HyperTree::Node &root = m_ruleTrie.GetRootNode();
00021 HyperPath::NodeSeq nodeSeq(1, v.pvertex.symbol[0]->GetId());
00022 const HyperTree::Node *child = root.GetChild(nodeSeq);
00023 if (!child) {
00024 return;
00025 }
00026
00027 m_hyperedge.head = const_cast<PVertex*>(&v.pvertex);
00028
00029
00030 MatchItem item;
00031 item.annotatedFNS.fns = FNS(1, &v);
00032 item.trieNode = child;
00033 m_queue.push(item);
00034
00035 while (!m_queue.empty()) {
00036 MatchItem item = m_queue.front();
00037 m_queue.pop();
00038 if (item.trieNode->HasRules()) {
00039 const FNS &fns = item.annotatedFNS.fns;
00040
00041 m_hyperedge.tail.clear();
00042 for (FNS::const_iterator p = fns.begin(); p != fns.end(); ++p) {
00043 const Forest::Vertex *v = *p;
00044 m_hyperedge.tail.push_back(const_cast<PVertex *>(&(v->pvertex)));
00045 }
00046
00047 m_hyperedge.label.inputWeight = 0.0f;
00048 for (std::vector<const Forest::Hyperedge *>::const_iterator
00049 p = item.annotatedFNS.fragment.begin();
00050 p != item.annotatedFNS.fragment.end(); ++p) {
00051 m_hyperedge.label.inputWeight += (*p)->weight;
00052 }
00053
00054 m_hyperedge.label.translations
00055 = item.trieNode->GetTargetPhraseCollection();
00056
00057 callback(m_hyperedge);
00058 }
00059 PropagateNextLexel(item);
00060 }
00061 }
00062
00063 template<typename Callback>
00064 void RuleMatcherHyperTree<Callback>::PropagateNextLexel(const MatchItem &item)
00065 {
00066 std::vector<AnnotatedFNS> tfns;
00067 std::vector<AnnotatedFNS> rfns;
00068 std::vector<AnnotatedFNS> rfns2;
00069
00070 const HyperTree::Node &trieNode = *(item.trieNode);
00071 const HyperTree::Node::Map &map = trieNode.GetMap();
00072
00073 for (HyperTree::Node::Map::const_iterator p = map.begin();
00074 p != map.end(); ++p) {
00075 const HyperPath::NodeSeq &edgeLabel = p->first;
00076 const HyperTree::Node &child = p->second;
00077
00078 const int numSubSeqs = CountCommas(edgeLabel) + 1;
00079
00080 std::size_t pos = 0;
00081 for (int i = 0; i < numSubSeqs; ++i) {
00082 const FNS &fns = item.annotatedFNS.fns;
00083 tfns.clear();
00084 if (edgeLabel[pos] == HyperPath::kEpsilon) {
00085 AnnotatedFNS x;
00086 x.fns = FNS(1, fns[i]);
00087 tfns.push_back(x);
00088 pos += 2;
00089 } else {
00090 const int subSeqLength = SubSeqLength(edgeLabel, pos);
00091 const std::vector<Forest::Hyperedge*> &incoming = fns[i]->incoming;
00092 for (std::vector<Forest::Hyperedge *>::const_iterator q =
00093 incoming.begin(); q != incoming.end(); ++q) {
00094 const Forest::Hyperedge &edge = **q;
00095 if (MatchChildren(edge.tail, edgeLabel, pos, subSeqLength)) {
00096 tfns.resize(tfns.size()+1);
00097 tfns.back().fns.assign(edge.tail.begin(), edge.tail.end());
00098 tfns.back().fragment.push_back(&edge);
00099 }
00100 }
00101 pos += subSeqLength + 1;
00102 }
00103 if (tfns.empty()) {
00104 rfns.clear();
00105 break;
00106 } else if (i == 0) {
00107 rfns.swap(tfns);
00108 } else {
00109 CartesianProduct(rfns, tfns, rfns2);
00110 rfns.swap(rfns2);
00111 }
00112 }
00113
00114 for (typename std::vector<AnnotatedFNS>::const_iterator q = rfns.begin();
00115 q != rfns.end(); ++q) {
00116 MatchItem newItem;
00117 newItem.annotatedFNS.fns = q->fns;
00118 newItem.annotatedFNS.fragment = item.annotatedFNS.fragment;
00119 newItem.annotatedFNS.fragment.insert(newItem.annotatedFNS.fragment.end(),
00120 q->fragment.begin(),
00121 q->fragment.end());
00122 newItem.trieNode = &child;
00123 m_queue.push(newItem);
00124 }
00125 }
00126 }
00127
00128 template<typename Callback>
00129 void RuleMatcherHyperTree<Callback>::CartesianProduct(
00130 const std::vector<AnnotatedFNS> &x,
00131 const std::vector<AnnotatedFNS> &y,
00132 std::vector<AnnotatedFNS> &z)
00133 {
00134 z.clear();
00135 z.reserve(x.size() * y.size());
00136 for (typename std::vector<AnnotatedFNS>::const_iterator p = x.begin();
00137 p != x.end(); ++p) {
00138 const AnnotatedFNS &a = *p;
00139 for (typename std::vector<AnnotatedFNS>::const_iterator q = y.begin();
00140 q != y.end(); ++q) {
00141 const AnnotatedFNS &b = *q;
00142
00143 z.resize(z.size()+1);
00144 AnnotatedFNS &c = z.back();
00145
00146 c.fns.reserve(a.fns.size() + b.fns.size());
00147 c.fns.assign(a.fns.begin(), a.fns.end());
00148 c.fns.insert(c.fns.end(), b.fns.begin(), b.fns.end());
00149
00150 c.fragment.reserve(a.fragment.size() + b.fragment.size());
00151 c.fragment.assign(a.fragment.begin(), a.fragment.end());
00152 c.fragment.insert(c.fragment.end(), b.fragment.begin(), b.fragment.end());
00153 }
00154 }
00155 }
00156
00157 template<typename Callback>
00158 bool RuleMatcherHyperTree<Callback>::MatchChildren(
00159 const std::vector<Forest::Vertex *> &children,
00160 const HyperPath::NodeSeq &edgeLabel,
00161 std::size_t pos,
00162 std::size_t subSeqSize)
00163 {
00164 if (children.size() != subSeqSize) {
00165 return false;
00166 }
00167 for (size_t i = 0; i < subSeqSize; ++i) {
00168 if (edgeLabel[pos+i] != children[i]->pvertex.symbol[0]->GetId()) {
00169 return false;
00170 }
00171 }
00172 return true;
00173 }
00174
00175 template<typename Callback>
00176 int RuleMatcherHyperTree<Callback>::CountCommas(const HyperPath::NodeSeq &seq)
00177 {
00178 int count = 0;
00179 for (std::vector<std::size_t>::const_iterator p = seq.begin();
00180 p != seq.end(); ++p) {
00181 if (*p == HyperPath::kComma) {
00182 ++count;
00183 }
00184 }
00185 return count;
00186 }
00187
00188 template<typename Callback>
00189 int RuleMatcherHyperTree<Callback>::SubSeqLength(const HyperPath::NodeSeq &seq,
00190 int pos)
00191 {
00192 int length = 0;
00193 HyperPath::NodeSeq::size_type curpos = pos;
00194 while (curpos != seq.size() && seq[curpos] != HyperPath::kComma) {
00195 ++curpos;
00196 ++length;
00197 }
00198 return length;
00199 }
00200
00201 }
00202 }
00203 }