00001 
00002 
00003 
00004 
00005 
00006 
00007 
00008 
00009 
00010 
00011 
00012 
00013 
00014 
00015 
00016 
00017 
00018 
00019 
00020 #include "ScfgRule.h"
00021 
00022 #include <algorithm>
00023 
00024 #include "Node.h"
00025 #include "Subgraph.h"
00026 #include "SyntaxNode.h"
00027 #include "SyntaxNodeCollection.h"
00028 
00029 namespace MosesTraining
00030 {
00031 namespace Syntax
00032 {
00033 namespace GHKM
00034 {
00035 
00036 ScfgRule::ScfgRule(const Subgraph &fragment,
00037                    const SyntaxNodeCollection *sourceNodeCollection)
00038   : m_graphFragment(fragment)
00039   , m_sourceLHS("X", NonTerminal)
00040   , m_targetLHS(fragment.GetRoot()->GetLabel(), NonTerminal)
00041   , m_pcfgScore(fragment.GetPcfgScore())
00042   , m_hasSourceLabels(sourceNodeCollection)
00043 {
00044 
00045   
00046 
00047   const std::set<const Node *> &leaves = fragment.GetLeaves();
00048 
00049   std::vector<const Node *> sourceRHSNodes;
00050   sourceRHSNodes.reserve(leaves.size());
00051   for (std::set<const Node *>::const_iterator p(leaves.begin());
00052        p != leaves.end(); ++p) {
00053     const Node &leaf = **p;
00054     if (!leaf.GetSpan().empty()) {
00055       sourceRHSNodes.push_back(&leaf);
00056     }
00057   }
00058 
00059   std::sort(sourceRHSNodes.begin(), sourceRHSNodes.end(), PartitionOrderComp);
00060 
00061   
00062   
00063   std::map<const Node *, std::vector<int> > sourceOrder;
00064 
00065   m_sourceRHS.reserve(sourceRHSNodes.size());
00066   m_numberOfNonTerminals = 0;
00067   int srcIndex = 0;
00068   for (std::vector<const Node *>::const_iterator p(sourceRHSNodes.begin());
00069        p != sourceRHSNodes.end(); ++p, ++srcIndex) {
00070     const Node &sinkNode = **p;
00071     if (sinkNode.GetType() == TREE) {
00072       m_sourceRHS.push_back(Symbol("X", NonTerminal));
00073       sourceOrder[&sinkNode].push_back(srcIndex);
00074       ++m_numberOfNonTerminals;
00075     } else {
00076       assert(sinkNode.GetType() == SOURCE);
00077       m_sourceRHS.push_back(Symbol(sinkNode.GetLabel(), Terminal));
00078       
00079       const std::vector<Node *> &parents(sinkNode.GetParents());
00080       for (std::vector<Node *>::const_iterator q(parents.begin());
00081            q != parents.end(); ++q) {
00082         if ((*q)->GetType() == TARGET) {
00083           sourceOrder[*q].push_back(srcIndex);
00084         }
00085       }
00086     }
00087     if (sourceNodeCollection) {
00088       
00089       PushSourceLabel(sourceNodeCollection,&sinkNode,"XRHS");
00090     }
00091   }
00092 
00093   
00094 
00095   std::vector<const Node *> targetLeaves;
00096   fragment.GetTargetLeaves(targetLeaves);
00097 
00098   m_alignment.reserve(targetLeaves.size());  
00099   m_targetRHS.reserve(targetLeaves.size());
00100 
00101   for (std::vector<const Node *>::const_iterator p(targetLeaves.begin());
00102        p != targetLeaves.end(); ++p) {
00103     const Node &leaf = **p;
00104     if (leaf.GetSpan().empty()) {
00105       
00106       
00107       std::vector<std::string> targetWords(leaf.GetTargetWords());
00108       for (std::vector<std::string>::const_iterator q(targetWords.begin());
00109            q != targetWords.end(); ++q) {
00110         m_targetRHS.push_back(Symbol(*q, Terminal));
00111       }
00112     } else if (leaf.GetType() == SOURCE) {
00113       
00114     } else {
00115       SymbolType type = (leaf.GetType() == TREE) ? NonTerminal : Terminal;
00116       m_targetRHS.push_back(Symbol(leaf.GetLabel(), type));
00117 
00118       int tgtIndex = m_targetRHS.size()-1;
00119       std::map<const Node *, std::vector<int> >::iterator q(sourceOrder.find(&leaf));
00120       assert(q != sourceOrder.end());
00121       std::vector<int> &sourceNodes = q->second;
00122       for (std::vector<int>::iterator r(sourceNodes.begin());
00123            r != sourceNodes.end(); ++r) {
00124         int srcIndex = *r;
00125         m_alignment.push_back(std::make_pair(srcIndex, tgtIndex));
00126       }
00127     }
00128   }
00129 
00130   if (sourceNodeCollection) {
00131     
00132     PushSourceLabel(sourceNodeCollection,fragment.GetRoot(),"XLHS");
00133     
00134     
00135 
00136   }
00137 }
00138 
00139 void ScfgRule::PushSourceLabel(const SyntaxNodeCollection *sourceNodeCollection,
00140                                const Node *node,
00141                                const std::string &nonMatchingLabel)
00142 {
00143   ContiguousSpan span = Closure(node->GetSpan());
00144   if (sourceNodeCollection->HasNode(span.first,span.second)) { 
00145     std::vector<SyntaxNode*> sourceLabels =
00146       sourceNodeCollection->GetNodes(span.first,span.second);
00147     if (!sourceLabels.empty()) {
00148       
00149       m_sourceLabels.push_back(sourceLabels.back()->label);
00150     }
00151   } else {
00152     
00153     m_sourceLabels.push_back(nonMatchingLabel);
00154   }
00155 }
00156 
00157 
00158 void ScfgRule::UpdateSourceLabelCoocCounts(std::map< std::string, std::map<std::string,float>* > &coocCounts, float count) const
00159 {
00160   std::map<int, int> sourceToTargetNTMap;
00161   std::map<int, int> targetToSourceNTMap;
00162 
00163   for (Alignment::const_iterator p(m_alignment.begin());
00164        p != m_alignment.end(); ++p) {
00165     if ( m_sourceRHS[p->first].GetType() == NonTerminal ) {
00166       assert(m_targetRHS[p->second].GetType() == NonTerminal);
00167       sourceToTargetNTMap[p->first] = p->second;
00168     }
00169   }
00170 
00171   size_t sourceIndex = 0;
00172   size_t sourceNonTerminalIndex = 0;
00173   for (std::vector<Symbol>::const_iterator p=m_sourceRHS.begin();
00174        p != m_sourceRHS.end(); ++p, ++sourceIndex) {
00175     if ( p->GetType() == NonTerminal ) {
00176       const std::string &sourceLabel = m_sourceLabels[sourceNonTerminalIndex];
00177       int targetIndex = sourceToTargetNTMap[sourceIndex];
00178       const std::string &targetLabel = m_targetRHS[targetIndex].GetValue();
00179       ++sourceNonTerminalIndex;
00180 
00181       std::map<std::string,float>* countMap = NULL;
00182       std::map< std::string, std::map<std::string,float>* >::iterator iter = coocCounts.find(sourceLabel);
00183       if ( iter == coocCounts.end() ) {
00184         std::map<std::string,float> *newCountMap = new std::map<std::string,float>();
00185         std::pair< std::map< std::string, std::map<std::string,float>* >::iterator, bool > inserted =
00186           coocCounts.insert( std::pair< std::string, std::map<std::string,float>* >(sourceLabel, newCountMap) );
00187         assert(inserted.second);
00188         countMap = (inserted.first)->second;
00189       } else {
00190         countMap = iter->second;
00191       }
00192       std::pair< std::map<std::string,float>::iterator, bool > inserted =
00193         countMap->insert( std::pair< std::string,float>(targetLabel, count) );
00194       if ( !inserted.second ) {
00195         (inserted.first)->second += count;
00196       }
00197     }
00198   }
00199 }
00200 
00201 }  
00202 }  
00203 }