00001 #include "StsgRule.h"
00002
00003 #include <algorithm>
00004
00005 #include "Node.h"
00006 #include "Subgraph.h"
00007
00008 namespace MosesTraining
00009 {
00010 namespace Syntax
00011 {
00012 namespace GHKM
00013 {
00014
00015 StsgRule::StsgRule(const Subgraph &fragment)
00016 : m_targetSide(fragment, true)
00017 {
00018
00019
00020 const std::set<const Node *> &sinkNodes = fragment.GetLeaves();
00021
00022
00023
00024 std::vector<const Node *> productiveSinks;
00025 productiveSinks.reserve(sinkNodes.size());
00026 for (std::set<const Node *>::const_iterator p = sinkNodes.begin();
00027 p != sinkNodes.end(); ++p) {
00028 const Node *sink = *p;
00029 if (!sink->GetSpan().empty()) {
00030 productiveSinks.push_back(sink);
00031 }
00032 }
00033
00034
00035 std::sort(productiveSinks.begin(), productiveSinks.end(), PartitionOrderComp);
00036
00037
00038
00039 std::map<const Node *, std::vector<int> > sinkToSourceIndices;
00040 std::map<const Node *, int> nonTermSinkToSourceIndex;
00041
00042 m_sourceSide.reserve(productiveSinks.size());
00043 int srcIndex = 0;
00044 int nonTermCount = 0;
00045 for (std::vector<const Node *>::const_iterator p = productiveSinks.begin();
00046 p != productiveSinks.end(); ++p, ++srcIndex) {
00047 const Node &sink = **p;
00048 if (sink.GetType() == TREE) {
00049 m_sourceSide.push_back(Symbol("X", NonTerminal));
00050 sinkToSourceIndices[&sink].push_back(srcIndex);
00051 nonTermSinkToSourceIndex[&sink] = nonTermCount++;
00052 } else {
00053 assert(sink.GetType() == SOURCE);
00054 m_sourceSide.push_back(Symbol(sink.GetLabel(), Terminal));
00055
00056 const std::vector<Node *> &parents(sink.GetParents());
00057 for (std::vector<Node *>::const_iterator q = parents.begin();
00058 q != parents.end(); ++q) {
00059 if ((*q)->GetType() == TARGET) {
00060 sinkToSourceIndices[*q].push_back(srcIndex);
00061 }
00062 }
00063 }
00064 }
00065
00066
00067
00068 std::vector<const Node *> targetLeaves;
00069 m_targetSide.GetTargetLeaves(targetLeaves);
00070
00071 m_alignment.reserve(targetLeaves.size());
00072 m_nonTermAlignment.resize(nonTermCount);
00073
00074 for (int i = 0, j = 0; i < targetLeaves.size(); ++i) {
00075 const Node *leaf = targetLeaves[i];
00076 assert(leaf->GetType() != SOURCE);
00077 if (leaf->GetSpan().empty()) {
00078 continue;
00079 }
00080 std::map<const Node *, std::vector<int> >::iterator p =
00081 sinkToSourceIndices.find(leaf);
00082 assert(p != sinkToSourceIndices.end());
00083 std::vector<int> &sourceNodes = p->second;
00084 for (std::vector<int>::iterator r = sourceNodes.begin();
00085 r != sourceNodes.end(); ++r) {
00086 int srcIndex = *r;
00087 m_alignment.push_back(std::make_pair(srcIndex, i));
00088 }
00089 if (leaf->GetType() == TREE) {
00090 m_nonTermAlignment[nonTermSinkToSourceIndex[leaf]] = j++;
00091 }
00092 }
00093 }
00094
00095 }
00096 }
00097 }