00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020 #include "AlignmentGraph.h"
00021
00022 #include <algorithm>
00023 #include <cassert>
00024 #include <cstdlib>
00025 #include <memory>
00026 #include <stack>
00027
00028 #include "SyntaxTree.h"
00029
00030 #include "ComposedRule.h"
00031 #include "Node.h"
00032 #include "Options.h"
00033 #include "Subgraph.h"
00034
00035 namespace MosesTraining
00036 {
00037 namespace Syntax
00038 {
00039 namespace GHKM
00040 {
00041
00042 AlignmentGraph::AlignmentGraph(const SyntaxTree *t,
00043 const std::vector<std::string> &s,
00044 const Alignment &a)
00045 {
00046
00047 m_root = CopyParseTree(t);
00048
00049
00050 m_sourceNodes.reserve(s.size());
00051 for (std::vector<std::string>::const_iterator p(s.begin());
00052 p != s.end(); ++p) {
00053 m_sourceNodes.push_back(new Node(*p, SOURCE));
00054 }
00055
00056
00057
00058 std::vector<Node *> targetTreeLeaves;
00059 GetTargetTreeLeaves(m_root, targetTreeLeaves);
00060 for (Alignment::const_iterator p(a.begin()); p != a.end(); ++p) {
00061 Node *src = m_sourceNodes[p->first];
00062 Node *tgt = targetTreeLeaves[p->second];
00063 src->AddParent(tgt);
00064 tgt->AddChild(src);
00065 }
00066
00067
00068 AttachUnalignedSourceWords();
00069
00070
00071 std::vector<Node *>::const_iterator p(m_sourceNodes.begin());
00072 for (int i = 0; p != m_sourceNodes.end(); ++p, ++i) {
00073 (*p)->PropagateIndex(i);
00074 }
00075
00076
00077 CalcComplementSpans(m_root);
00078 }
00079
00080 AlignmentGraph::~AlignmentGraph()
00081 {
00082 for (std::vector<Node *>::iterator p(m_sourceNodes.begin());
00083 p != m_sourceNodes.end(); ++p) {
00084 delete *p;
00085 }
00086 for (std::vector<Node *>::iterator p(m_targetNodes.begin());
00087 p != m_targetNodes.end(); ++p) {
00088 delete *p;
00089 }
00090 }
00091
00092 Subgraph AlignmentGraph::ComputeMinimalFrontierGraphFragment(
00093 Node *root,
00094 const std::set<Node *> &frontierSet)
00095 {
00096 std::stack<Node *> expandableNodes;
00097 std::set<const Node *> expandedNodes;
00098
00099 if (root->IsSink()) {
00100 expandedNodes.insert(root);
00101 } else {
00102 expandableNodes.push(root);
00103 }
00104
00105 while (!expandableNodes.empty()) {
00106 Node *n = expandableNodes.top();
00107 expandableNodes.pop();
00108
00109 const std::vector<Node *> &children = n->GetChildren();
00110
00111 for (std::vector<Node *>::const_iterator p(children.begin());
00112 p != children.end(); ++p) {
00113 Node *child = *p;
00114 if (child->IsSink()) {
00115 expandedNodes.insert(child);
00116 continue;
00117 }
00118 std::set<Node *>::const_iterator q = frontierSet.find(child);
00119 if (q == frontierSet.end()) {
00120 expandableNodes.push(child);
00121 } else if (child->GetType() == TARGET) {
00122 expandableNodes.push(child);
00123 } else {
00124 expandedNodes.insert(child);
00125 }
00126 }
00127 }
00128
00129 return Subgraph(root, expandedNodes);
00130 }
00131
00132 void AlignmentGraph::ExtractMinimalRules(const Options &options)
00133 {
00134
00135 std::set<Node *> frontierSet;
00136 ComputeFrontierSet(m_root, options, frontierSet);
00137
00138
00139 std::vector<Subgraph> fragments;
00140 fragments.reserve(frontierSet.size());
00141 for (std::set<Node *>::iterator p(frontierSet.begin());
00142 p != frontierSet.end(); ++p) {
00143 Node *root = *p;
00144 Subgraph fragment = ComputeMinimalFrontierGraphFragment(root, frontierSet);
00145 assert(!fragment.IsTrivial());
00146
00147
00148 if (root->GetType() == TREE && !root->GetSpan().empty()) {
00149 root->AddRule(new Subgraph(fragment));
00150 }
00151 }
00152 }
00153
00154 void AlignmentGraph::ExtractComposedRules(const Options &options)
00155 {
00156 ExtractComposedRules(m_root, options);
00157 }
00158
00159 void AlignmentGraph::ExtractComposedRules(Node *node, const Options &options)
00160 {
00161
00162 const std::vector<Node *> &children = node->GetChildren();
00163 for (std::vector<Node *>::const_iterator p(children.begin());
00164 p != children.end(); ++p) {
00165 ExtractComposedRules(*p, options);
00166 }
00167
00168
00169
00170 const std::vector<const Subgraph*> &rules = node->GetRules();
00171 assert(rules.size() <= 1);
00172 if (rules.empty()) {
00173 return;
00174 }
00175
00176
00177 ComposedRule cr(*(rules[0]));
00178 if (!cr.GetOpenAttachmentPoint()) {
00179
00180 return;
00181 }
00182
00183 std::queue<ComposedRule> queue;
00184 queue.push(cr);
00185 while (!queue.empty()) {
00186 ComposedRule cr = queue.front();
00187 queue.pop();
00188 const Node *attachmentPoint = cr.GetOpenAttachmentPoint();
00189 assert(attachmentPoint);
00190 assert(attachmentPoint != node);
00191
00192
00193
00194 const std::vector<const Subgraph*> &rules = attachmentPoint->GetRules();
00195 for (std::vector<const Subgraph*>::const_iterator p = rules.begin();
00196 p != rules.end(); ++p) {
00197 assert((*p)->GetRoot()->GetType() == TREE);
00198 ComposedRule *cr2 = cr.AttemptComposition(**p, options);
00199 if (cr2) {
00200 node->AddRule(new Subgraph(cr2->CreateSubgraph()));
00201 if (cr2->GetOpenAttachmentPoint()) {
00202 queue.push(*cr2);
00203 }
00204 delete cr2;
00205 }
00206 }
00207
00208 cr.CloseAttachmentPoint();
00209 if (cr.GetOpenAttachmentPoint()) {
00210 queue.push(cr);
00211 }
00212 }
00213 }
00214
00215 Node *AlignmentGraph::CopyParseTree(const SyntaxTree *root)
00216 {
00217 NodeType nodeType = (root->IsLeaf()) ? TARGET : TREE;
00218
00219 std::auto_ptr<Node> n(new Node(root->value().label, nodeType));
00220
00221 if (nodeType == TREE) {
00222 float score = 0.0f;
00223 SyntaxNode::AttributeMap::const_iterator p =
00224 root->value().attributes.find("pcfg");
00225 if (p != root->value().attributes.end()) {
00226 score = std::atof(p->second.c_str());
00227 }
00228 n->SetPcfgScore(score);
00229 }
00230
00231 const std::vector<SyntaxTree *> &children = root->children();
00232 std::vector<Node *> childNodes;
00233 childNodes.reserve(children.size());
00234 for (std::vector<SyntaxTree *>::const_iterator p(children.begin());
00235 p != children.end(); ++p) {
00236 Node *child = CopyParseTree(*p);
00237 child->AddParent(n.get());
00238 childNodes.push_back(child);
00239 }
00240 n->SetChildren(childNodes);
00241
00242 Node *p = n.release();
00243 m_targetNodes.push_back(p);
00244 return p;
00245 }
00246
00247
00248
00249 void AlignmentGraph::ComputeFrontierSet(Node *root,
00250 const Options &options,
00251 std::set<Node *> &frontierSet) const
00252 {
00253
00254
00255
00256 if (root->GetType() != TREE || root->GetSpan().empty()) {
00257 return;
00258 }
00259
00260 if (IsFrontierNode(*root, options)) {
00261 frontierSet.insert(root);
00262 }
00263
00264
00265 const std::vector<Node *> &children = root->GetChildren();
00266 for (std::vector<Node *>::const_iterator p(children.begin());
00267 p != children.end(); ++p) {
00268 ComputeFrontierSet(*p, options, frontierSet);
00269 }
00270 }
00271
00272
00273
00274
00275
00276
00277
00278
00279
00280
00281 bool AlignmentGraph::IsFrontierNode(const Node &n, const Options &options) const
00282 {
00283
00284 if (n.GetType() != TREE || n.GetSpan().empty()) {
00285 return false;
00286 }
00287
00288 if (SpansIntersect(n.GetComplementSpan(), Closure(n.GetSpan()))) {
00289 return false;
00290 }
00291
00292
00293
00294 assert(n.GetParents().size() <= 1);
00295 if (!options.allowUnary &&
00296 !n.GetParents().empty() &&
00297 n.GetParents()[0]->GetSpan() == n.GetSpan()) {
00298 return false;
00299 }
00300 return true;
00301 }
00302
00303 void AlignmentGraph::CalcComplementSpans(Node *root)
00304 {
00305 Span compSpan;
00306 std::set<Node *> siblings;
00307
00308 const std::vector<Node *> &parents = root->GetParents();
00309 for (std::vector<Node *>::const_iterator p(parents.begin());
00310 p != parents.end(); ++p) {
00311 const Span &parentCompSpan = (*p)->GetComplementSpan();
00312 compSpan.insert(parentCompSpan.begin(), parentCompSpan.end());
00313 const std::vector<Node *> &c = (*p)->GetChildren();
00314 siblings.insert(c.begin(), c.end());
00315 }
00316
00317 for (std::set<Node *>::iterator p(siblings.begin());
00318 p != siblings.end(); ++p) {
00319 if (*p == root) {
00320 continue;
00321 }
00322 const Span &siblingSpan = (*p)->GetSpan();
00323 compSpan.insert(siblingSpan.begin(), siblingSpan.end());
00324 }
00325
00326 root->SetComplementSpan(compSpan);
00327
00328 const std::vector<Node *> &children = root->GetChildren();
00329 for (std::vector<Node *>::const_iterator p(children.begin());
00330 p != children.end(); ++p) {
00331 CalcComplementSpans(*p);
00332 }
00333 }
00334
00335 void AlignmentGraph::GetTargetTreeLeaves(Node *root,
00336 std::vector<Node *> &leaves)
00337 {
00338 if (root->IsSink()) {
00339 leaves.push_back(root);
00340 } else {
00341 const std::vector<Node *> &children = root->GetChildren();
00342 for (std::vector<Node *>::const_iterator p(children.begin());
00343 p != children.end(); ++p) {
00344 GetTargetTreeLeaves(*p, leaves);
00345 }
00346 }
00347 }
00348
00349 void AlignmentGraph::AttachUnalignedSourceWords()
00350 {
00351
00352 std::set<int> unaligned;
00353 for (size_t i = 0; i < m_sourceNodes.size(); ++i) {
00354 const Node &sourceNode = (*m_sourceNodes[i]);
00355 if (sourceNode.GetParents().empty()) {
00356 unaligned.insert(i);
00357 }
00358 }
00359
00360
00361 for (std::set<int>::iterator p = unaligned.begin();
00362 p != unaligned.end(); ++p) {
00363 int index = *p;
00364 Node *attachmentPoint = DetermineAttachmentPoint(index);
00365 Node *sourceNode = m_sourceNodes[index];
00366 attachmentPoint->AddChild(sourceNode);
00367 sourceNode->AddParent(attachmentPoint);
00368 }
00369 }
00370
00371 Node *AlignmentGraph::DetermineAttachmentPoint(int index)
00372 {
00373
00374 int i = index;
00375 while (--i >= 0) {
00376 if (!m_sourceNodes[i]->GetParents().empty()) {
00377 break;
00378 }
00379 }
00380
00381 if (i == -1) {
00382 return m_root;
00383 }
00384
00385 size_t j = index;
00386 while (++j < m_sourceNodes.size()) {
00387 if (!m_sourceNodes[j]->GetParents().empty()) {
00388 break;
00389 }
00390 }
00391
00392 if (j == m_sourceNodes.size()) {
00393 return m_root;
00394 }
00395
00396
00397 const std::vector<Node *> &leftParents = m_sourceNodes[i]->GetParents();
00398 assert(!leftParents.empty());
00399 const std::vector<Node *> &rightParents = m_sourceNodes[j]->GetParents();
00400 assert(!rightParents.empty());
00401 std::set<Node *> targetSet;
00402 targetSet.insert(leftParents.begin(), leftParents.end());
00403 targetSet.insert(rightParents.begin(), rightParents.end());
00404
00405
00406
00407
00408 Node *lca = Node::LowestCommonAncestor(targetSet.begin(), targetSet.end());
00409 if (lca->GetType() == TARGET) {
00410 assert(lca->GetParents().size() == 1);
00411 return lca->GetParents()[0];
00412 }
00413 return lca;
00414 }
00415
00416 }
00417 }
00418 }