00001 #include "ForestTsgFilter.h"
00002
00003 #include <boost/make_shared.hpp>
00004
00005 namespace MosesTraining
00006 {
00007 namespace Syntax
00008 {
00009 namespace FilterRuleTable
00010 {
00011
00012
00013
00014
00015
00016
00017 const std::size_t ForestTsgFilter::kMatchLimit = 10000;
00018
00019 ForestTsgFilter::ForestTsgFilter(
00020 const std::vector<boost::shared_ptr<StringForest> > &sentences)
00021 {
00022
00023 m_sentences.reserve(sentences.size());
00024 for (std::vector<boost::shared_ptr<StringForest> >::const_iterator p =
00025 sentences.begin(); p != sentences.end(); ++p) {
00026 m_sentences.push_back(StringForestToIdForest(**p));
00027 }
00028
00029
00030 m_idToSentence.resize(m_testVocab.Size());
00031 for (std::size_t i = 0; i < m_sentences.size(); ++i) {
00032 const IdForest &forest = *(m_sentences[i]);
00033 for (std::vector<IdForest::Vertex *>::const_iterator
00034 p = forest.vertices.begin(); p != forest.vertices.end(); ++p) {
00035 m_idToSentence[(*p)->value.id][i].push_back(*p);
00036 }
00037 }
00038 }
00039
00040 boost::shared_ptr<ForestTsgFilter::IdForest>
00041 ForestTsgFilter::StringForestToIdForest(const StringForest &f)
00042 {
00043 typedef StringForest::Vertex StringVertex;
00044 typedef StringForest::Hyperedge StringHyperedge;
00045 typedef IdForest::Vertex IdVertex;
00046 typedef IdForest::Hyperedge IdHyperedge;
00047
00048 boost::shared_ptr<IdForest> g = boost::make_shared<IdForest>();
00049
00050
00051 boost::unordered_map<const StringVertex *, const IdVertex *> vertexMap;
00052
00053
00054 for (std::vector<StringVertex *>::const_iterator p = f.vertices.begin();
00055 p != f.vertices.end(); ++p) {
00056 const StringVertex *v = *p;
00057 IdVertex *w = new IdVertex();
00058 w->value.id = m_testVocab.Insert(v->value.symbol);
00059 w->value.start = v->value.start;
00060 w->value.end = v->value.end;
00061 g->vertices.push_back(w);
00062 vertexMap[v] = w;
00063 }
00064
00065
00066 for (std::vector<StringVertex *>::const_iterator p = f.vertices.begin();
00067 p != f.vertices.end(); ++p) {
00068 for (std::vector<StringHyperedge *>::const_iterator
00069 q = (*p)->incoming.begin(); q != (*p)->incoming.end(); ++q) {
00070 IdHyperedge *e = new IdHyperedge();
00071 e->head = const_cast<IdVertex *>(vertexMap[(*q)->head]);
00072 e->tail.reserve((*q)->tail.size());
00073 for (std::vector<StringVertex*>::const_iterator
00074 r = (*q)->tail.begin(); r != (*q)->tail.end(); ++r) {
00075 e->tail.push_back(const_cast<IdVertex *>(vertexMap[*r]));
00076 }
00077 e->head->incoming.push_back(e);
00078 }
00079 }
00080
00081 return g;
00082 }
00083
00084 bool ForestTsgFilter::MatchFragment(const IdTree &fragment,
00085 const std::vector<IdTree *> &leaves)
00086 {
00087 typedef std::vector<const IdTree *> TreeVec;
00088
00089
00090 m_matchCount = 0;
00091
00092
00093
00094
00095
00096
00097 const IdTree *rarestLeaf = leaves[0];
00098 std::size_t lowestCount = m_idToSentence[rarestLeaf->value()].size();
00099 for (std::size_t i = 1; i < leaves.size(); ++i) {
00100 const IdTree *leaf = leaves[i];
00101 std::size_t count = m_idToSentence[leaf->value()].size();
00102 if (count < lowestCount) {
00103 lowestCount = count;
00104 rarestLeaf = leaf;
00105 }
00106 }
00107
00108
00109
00110 const InnerMap &leafSentenceMap = m_idToSentence[rarestLeaf->value()];
00111 const InnerMap &rootSentenceMap = m_idToSentence[fragment.value()];
00112
00113 std::vector<std::pair<std::size_t, std::size_t> > spans;
00114
00115 for (InnerMap::const_iterator p = leafSentenceMap.begin();
00116 p != leafSentenceMap.end(); ++p) {
00117 std::size_t i = p->first;
00118
00119
00120 InnerMap::const_iterator q = rootSentenceMap.find(i);
00121 if (q == rootSentenceMap.end()) {
00122 continue;
00123 }
00124 const std::vector<const IdForest::Vertex*> &candidates = q->second;
00125
00126 spans.clear();
00127 for (std::vector<const IdForest::Vertex*>::const_iterator
00128 r = p->second.begin(); r != p->second.end(); ++r) {
00129 spans.push_back(std::make_pair((*r)->value.start, (*r)->value.end));
00130 }
00131
00132 for (std::vector<const IdForest::Vertex*>::const_iterator
00133 r = candidates.begin(); r != candidates.end(); ++r) {
00134 const IdForest::Vertex &v = **r;
00135
00136
00137 if (v.value.end - v.value.start + 1 < leaves.size()) {
00138 continue;
00139 }
00140
00141 bool covered = false;
00142 for (std::vector<std::pair<std::size_t, std::size_t> >::const_iterator
00143 s = spans.begin(); s != spans.end(); ++s) {
00144 if (v.value.start <= s->first && v.value.end >= s->second) {
00145 covered = true;
00146 break;
00147 }
00148 }
00149 if (!covered) {
00150 continue;
00151 }
00152
00153 if (MatchFragment(fragment, v)) {
00154 return true;
00155 }
00156 }
00157 }
00158 return false;
00159 }
00160
00161 bool ForestTsgFilter::MatchFragment(const IdTree &fragment,
00162 const IdForest::Vertex &v)
00163 {
00164 if (++m_matchCount >= kMatchLimit) {
00165 return true;
00166 }
00167 if (fragment.value() != v.value.id) {
00168 return false;
00169 }
00170 const std::vector<IdTree*> &children = fragment.children();
00171 if (children.empty()) {
00172 return true;
00173 }
00174 for (std::vector<IdForest::Hyperedge *>::const_iterator
00175 p = v.incoming.begin(); p != v.incoming.end(); ++p) {
00176 const std::vector<IdForest::Vertex*> &tail = (*p)->tail;
00177 if (children.size() != tail.size()) {
00178 continue;
00179 }
00180 bool match = true;
00181 for (std::size_t i = 0; i < children.size(); ++i) {
00182 if (!MatchFragment(*children[i], *tail[i])) {
00183 match = false;
00184 break;
00185 }
00186 }
00187 if (match) {
00188 return true;
00189 }
00190 }
00191 return false;
00192 }
00193
00194 }
00195 }
00196 }