00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020 #include "ScfgRule.h"
00021
00022 #include <algorithm>
00023
00024 #include "Node.h"
00025 #include "Subgraph.h"
00026 #include "SyntaxNode.h"
00027 #include "SyntaxNodeCollection.h"
00028
00029 namespace MosesTraining
00030 {
00031 namespace Syntax
00032 {
00033 namespace GHKM
00034 {
00035
00036 ScfgRule::ScfgRule(const Subgraph &fragment,
00037 const SyntaxNodeCollection *sourceNodeCollection)
00038 : m_graphFragment(fragment)
00039 , m_sourceLHS("X", NonTerminal)
00040 , m_targetLHS(fragment.GetRoot()->GetLabel(), NonTerminal)
00041 , m_pcfgScore(fragment.GetPcfgScore())
00042 , m_hasSourceLabels(sourceNodeCollection)
00043 {
00044
00045
00046
00047 const std::set<const Node *> &leaves = fragment.GetLeaves();
00048
00049 std::vector<const Node *> sourceRHSNodes;
00050 sourceRHSNodes.reserve(leaves.size());
00051 for (std::set<const Node *>::const_iterator p(leaves.begin());
00052 p != leaves.end(); ++p) {
00053 const Node &leaf = **p;
00054 if (!leaf.GetSpan().empty()) {
00055 sourceRHSNodes.push_back(&leaf);
00056 }
00057 }
00058
00059 std::sort(sourceRHSNodes.begin(), sourceRHSNodes.end(), PartitionOrderComp);
00060
00061
00062
00063 std::map<const Node *, std::vector<int> > sourceOrder;
00064
00065 m_sourceRHS.reserve(sourceRHSNodes.size());
00066 m_numberOfNonTerminals = 0;
00067 int srcIndex = 0;
00068 for (std::vector<const Node *>::const_iterator p(sourceRHSNodes.begin());
00069 p != sourceRHSNodes.end(); ++p, ++srcIndex) {
00070 const Node &sinkNode = **p;
00071 if (sinkNode.GetType() == TREE) {
00072 m_sourceRHS.push_back(Symbol("X", NonTerminal));
00073 sourceOrder[&sinkNode].push_back(srcIndex);
00074 ++m_numberOfNonTerminals;
00075 } else {
00076 assert(sinkNode.GetType() == SOURCE);
00077 m_sourceRHS.push_back(Symbol(sinkNode.GetLabel(), Terminal));
00078
00079 const std::vector<Node *> &parents(sinkNode.GetParents());
00080 for (std::vector<Node *>::const_iterator q(parents.begin());
00081 q != parents.end(); ++q) {
00082 if ((*q)->GetType() == TARGET) {
00083 sourceOrder[*q].push_back(srcIndex);
00084 }
00085 }
00086 }
00087 if (sourceNodeCollection) {
00088
00089 PushSourceLabel(sourceNodeCollection,&sinkNode,"XRHS");
00090 }
00091 }
00092
00093
00094
00095 std::vector<const Node *> targetLeaves;
00096 fragment.GetTargetLeaves(targetLeaves);
00097
00098 m_alignment.reserve(targetLeaves.size());
00099 m_targetRHS.reserve(targetLeaves.size());
00100
00101 for (std::vector<const Node *>::const_iterator p(targetLeaves.begin());
00102 p != targetLeaves.end(); ++p) {
00103 const Node &leaf = **p;
00104 if (leaf.GetSpan().empty()) {
00105
00106
00107 std::vector<std::string> targetWords(leaf.GetTargetWords());
00108 for (std::vector<std::string>::const_iterator q(targetWords.begin());
00109 q != targetWords.end(); ++q) {
00110 m_targetRHS.push_back(Symbol(*q, Terminal));
00111 }
00112 } else if (leaf.GetType() == SOURCE) {
00113
00114 } else {
00115 SymbolType type = (leaf.GetType() == TREE) ? NonTerminal : Terminal;
00116 m_targetRHS.push_back(Symbol(leaf.GetLabel(), type));
00117
00118 int tgtIndex = m_targetRHS.size()-1;
00119 std::map<const Node *, std::vector<int> >::iterator q(sourceOrder.find(&leaf));
00120 assert(q != sourceOrder.end());
00121 std::vector<int> &sourceNodes = q->second;
00122 for (std::vector<int>::iterator r(sourceNodes.begin());
00123 r != sourceNodes.end(); ++r) {
00124 int srcIndex = *r;
00125 m_alignment.push_back(std::make_pair(srcIndex, tgtIndex));
00126 }
00127 }
00128 }
00129
00130 if (sourceNodeCollection) {
00131
00132 PushSourceLabel(sourceNodeCollection,fragment.GetRoot(),"XLHS");
00133
00134
00135
00136 }
00137 }
00138
00139 void ScfgRule::PushSourceLabel(const SyntaxNodeCollection *sourceNodeCollection,
00140 const Node *node,
00141 const std::string &nonMatchingLabel)
00142 {
00143 ContiguousSpan span = Closure(node->GetSpan());
00144 if (sourceNodeCollection->HasNode(span.first,span.second)) {
00145 std::vector<SyntaxNode*> sourceLabels =
00146 sourceNodeCollection->GetNodes(span.first,span.second);
00147 if (!sourceLabels.empty()) {
00148
00149 m_sourceLabels.push_back(sourceLabels.back()->label);
00150 }
00151 } else {
00152
00153 m_sourceLabels.push_back(nonMatchingLabel);
00154 }
00155 }
00156
00157
00158 void ScfgRule::UpdateSourceLabelCoocCounts(std::map< std::string, std::map<std::string,float>* > &coocCounts, float count) const
00159 {
00160 std::map<int, int> sourceToTargetNTMap;
00161 std::map<int, int> targetToSourceNTMap;
00162
00163 for (Alignment::const_iterator p(m_alignment.begin());
00164 p != m_alignment.end(); ++p) {
00165 if ( m_sourceRHS[p->first].GetType() == NonTerminal ) {
00166 assert(m_targetRHS[p->second].GetType() == NonTerminal);
00167 sourceToTargetNTMap[p->first] = p->second;
00168 }
00169 }
00170
00171 size_t sourceIndex = 0;
00172 size_t sourceNonTerminalIndex = 0;
00173 for (std::vector<Symbol>::const_iterator p=m_sourceRHS.begin();
00174 p != m_sourceRHS.end(); ++p, ++sourceIndex) {
00175 if ( p->GetType() == NonTerminal ) {
00176 const std::string &sourceLabel = m_sourceLabels[sourceNonTerminalIndex];
00177 int targetIndex = sourceToTargetNTMap[sourceIndex];
00178 const std::string &targetLabel = m_targetRHS[targetIndex].GetValue();
00179 ++sourceNonTerminalIndex;
00180
00181 std::map<std::string,float>* countMap = NULL;
00182 std::map< std::string, std::map<std::string,float>* >::iterator iter = coocCounts.find(sourceLabel);
00183 if ( iter == coocCounts.end() ) {
00184 std::map<std::string,float> *newCountMap = new std::map<std::string,float>();
00185 std::pair< std::map< std::string, std::map<std::string,float>* >::iterator, bool > inserted =
00186 coocCounts.insert( std::pair< std::string, std::map<std::string,float>* >(sourceLabel, newCountMap) );
00187 assert(inserted.second);
00188 countMap = (inserted.first)->second;
00189 } else {
00190 countMap = iter->second;
00191 }
00192 std::pair< std::map<std::string,float>::iterator, bool > inserted =
00193 countMap->insert( std::pair< std::string,float>(targetLabel, count) );
00194 if ( !inserted.second ) {
00195 (inserted.first)->second += count;
00196 }
00197 }
00198 }
00199 }
00200
00201 }
00202 }
00203 }