00001 #include "RuleTrieCYKPlus.h"
00002
00003 #include <map>
00004 #include <vector>
00005
00006 #include <boost/functional/hash.hpp>
00007 #include <boost/unordered_map.hpp>
00008 #include <boost/version.hpp>
00009
00010 #include "moses/NonTerminal.h"
00011 #include "moses/TargetPhrase.h"
00012 #include "moses/TargetPhraseCollection.h"
00013 #include "moses/Util.h"
00014 #include "moses/Word.h"
00015
00016 namespace Moses
00017 {
00018 namespace Syntax
00019 {
00020 namespace S2T
00021 {
00022
00023 void RuleTrieCYKPlus::Node::Prune(std::size_t tableLimit)
00024 {
00025
00026 for (SymbolMap::iterator p = m_sourceTermMap.begin();
00027 p != m_sourceTermMap.end(); ++p) {
00028 p->second.Prune(tableLimit);
00029 }
00030 for (SymbolMap::iterator p = m_nonTermMap.begin();
00031 p != m_nonTermMap.end(); ++p) {
00032 p->second.Prune(tableLimit);
00033 }
00034
00035
00036 m_targetPhraseCollection->Prune(true, tableLimit);
00037 }
00038
00039 void RuleTrieCYKPlus::Node::Sort(std::size_t tableLimit)
00040 {
00041
00042 for (SymbolMap::iterator p = m_sourceTermMap.begin();
00043 p != m_sourceTermMap.end(); ++p) {
00044 p->second.Sort(tableLimit);
00045 }
00046 for (SymbolMap::iterator p = m_nonTermMap.begin();
00047 p != m_nonTermMap.end(); ++p) {
00048 p->second.Sort(tableLimit);
00049 }
00050
00051
00052 m_targetPhraseCollection->Sort(true, tableLimit);
00053 }
00054
00055 RuleTrieCYKPlus::Node *RuleTrieCYKPlus::Node::GetOrCreateChild(
00056 const Word &sourceTerm)
00057 {
00058 return &m_sourceTermMap[sourceTerm];
00059 }
00060
00061 RuleTrieCYKPlus::Node *RuleTrieCYKPlus::Node::GetOrCreateNonTerminalChild(const Word &targetNonTerm)
00062 {
00063 UTIL_THROW_IF2(!targetNonTerm.IsNonTerminal(),
00064 "Not a non-terminal: " << targetNonTerm);
00065
00066 return &m_nonTermMap[targetNonTerm];
00067 }
00068
00069 const RuleTrieCYKPlus::Node *RuleTrieCYKPlus::Node::GetChild(
00070 const Word &sourceTerm) const
00071 {
00072 UTIL_THROW_IF2(sourceTerm.IsNonTerminal(),
00073 "Not a terminal: " << sourceTerm);
00074
00075 SymbolMap::const_iterator p = m_sourceTermMap.find(sourceTerm);
00076 return (p == m_sourceTermMap.end()) ? NULL : &p->second;
00077 }
00078
00079 const RuleTrieCYKPlus::Node *RuleTrieCYKPlus::Node::GetNonTerminalChild(
00080 const Word &targetNonTerm) const
00081 {
00082 UTIL_THROW_IF2(!targetNonTerm.IsNonTerminal(),
00083 "Not a non-terminal: " << targetNonTerm);
00084
00085 SymbolMap::const_iterator p = m_nonTermMap.find(targetNonTerm);
00086 return (p == m_nonTermMap.end()) ? NULL : &p->second;
00087 }
00088
00089 TargetPhraseCollection::shared_ptr
00090 RuleTrieCYKPlus::
00091 GetOrCreateTargetPhraseCollection(const Phrase &source,
00092 const TargetPhrase &target,
00093 const Word *sourceLHS)
00094 {
00095 Node &currNode = GetOrCreateNode(source, target, sourceLHS);
00096 return currNode.GetTargetPhraseCollection();
00097 }
00098
00099 RuleTrieCYKPlus::Node &RuleTrieCYKPlus::GetOrCreateNode(
00100 const Phrase &source, const TargetPhrase &target, const Word *sourceLHS)
00101 {
00102 const std::size_t size = source.GetSize();
00103
00104 const AlignmentInfo &alignmentInfo = target.GetAlignNonTerm();
00105 AlignmentInfo::const_iterator iterAlign = alignmentInfo.begin();
00106
00107 Node *currNode = &m_root;
00108 for (std::size_t pos = 0 ; pos < size ; ++pos) {
00109 const Word& word = source.GetWord(pos);
00110
00111 if (word.IsNonTerminal()) {
00112 UTIL_THROW_IF2(iterAlign == alignmentInfo.end(),
00113 "No alignment for non-term at position " << pos);
00114 UTIL_THROW_IF2(iterAlign->first != pos,
00115 "Alignment info incorrect at position " << pos);
00116 std::size_t targetNonTermInd = iterAlign->second;
00117 ++iterAlign;
00118 const Word &targetNonTerm = target.GetWord(targetNonTermInd);
00119 currNode = currNode->GetOrCreateNonTerminalChild(targetNonTerm);
00120 } else {
00121 currNode = currNode->GetOrCreateChild(word);
00122 }
00123
00124 UTIL_THROW_IF2(currNode == NULL, "Node not found at position " << pos);
00125 }
00126
00127 return *currNode;
00128 }
00129
00130 void RuleTrieCYKPlus::SortAndPrune(std::size_t tableLimit)
00131 {
00132 if (tableLimit) {
00133 m_root.Sort(tableLimit);
00134 }
00135 }
00136
00137 bool RuleTrieCYKPlus::HasPreterminalRule(const Word &w) const
00138 {
00139 const Node::SymbolMap &map = m_root.GetTerminalMap();
00140 Node::SymbolMap::const_iterator p = map.find(w);
00141 return p != map.end() && p->second.HasRules();
00142 }
00143
00144 }
00145 }
00146 }