00001 #include "StringCfgFilter.h"
00002
00003 #include <algorithm>
00004
00005 #include "util/string_piece_hash.hh"
00006
00007 namespace MosesTraining
00008 {
00009 namespace Syntax
00010 {
00011 namespace FilterRuleTable
00012 {
00013
00014 const std::size_t StringCfgFilter::kMaxNGramLength = 5;
00015
00016 StringCfgFilter::StringCfgFilter(
00017 const std::vector<boost::shared_ptr<std::string> > &sentences)
00018 : m_maxSentenceLength(-1)
00019 {
00020
00021
00022 m_sentenceLengths.reserve(sentences.size());
00023 const util::AnyCharacter delimiter(" \t");
00024 std::vector<Vocabulary::IdType> vocabIds;
00025 for (std::size_t i = 0; i < sentences.size(); ++i) {
00026 vocabIds.clear();
00027 for (util::TokenIter<util::AnyCharacter, true> p(*sentences[i], delimiter);
00028 p; ++p) {
00029 std::string tmp;
00030 p->CopyToString(&tmp);
00031 vocabIds.push_back(m_testVocab.Insert(tmp));
00032 }
00033 AddSentenceNGrams(vocabIds, i);
00034 const int sentenceLength = static_cast<int>(vocabIds.size());
00035 m_sentenceLengths.push_back(sentenceLength);
00036 m_maxSentenceLength = std::max(sentenceLength, m_maxSentenceLength);
00037 }
00038
00039
00040 for (NGramCoordinateMap::iterator p = m_ngramCoordinateMap.begin();
00041 p != m_ngramCoordinateMap.end(); ++p) {
00042 CoordinateTable &ct = p->second;
00043 ct.sentences.reserve(ct.intraSentencePositions.size());
00044 for (boost::unordered_map<int, PositionSeq>::const_iterator
00045 q = ct.intraSentencePositions.begin();
00046 q != ct.intraSentencePositions.end(); ++q) {
00047 ct.sentences.push_back(q->first);
00048 }
00049 std::sort(ct.sentences.begin(), ct.sentences.end());
00050 }
00051 }
00052
00053 void StringCfgFilter::Filter(std::istream &in, std::ostream &out)
00054 {
00055 const util::MultiCharacter fieldDelimiter("|||");
00056 const util::AnyCharacter symbolDelimiter(" \t");
00057
00058 std::string line;
00059 std::string prevLine;
00060 StringPiece source;
00061 std::vector<StringPiece> symbols;
00062 Pattern pattern;
00063 bool keep = true;
00064 int lineNum = 0;
00065
00066 while (std::getline(in, line)) {
00067 ++lineNum;
00068
00069
00070 util::TokenIter<util::MultiCharacter> it(line, fieldDelimiter);
00071
00072
00073
00074
00075
00076 if (*it == source) {
00077 if (keep) {
00078 out << line << std::endl;
00079 }
00080 continue;
00081 }
00082
00083
00084 source = *it;
00085
00086
00087 symbols.clear();
00088 for (util::TokenIter<util::AnyCharacter, true> p(source, symbolDelimiter);
00089 p; ++p) {
00090 symbols.push_back(*p);
00091 }
00092
00093
00094
00095 keep = GeneratePattern(symbols, pattern) && MatchPattern(pattern);
00096 if (keep) {
00097 out << line << std::endl;
00098 }
00099
00100
00101
00102 prevLine.swap(line);
00103 }
00104 }
00105
00106 void StringCfgFilter::AddSentenceNGrams(
00107 const std::vector<Vocabulary::IdType> &s, std::size_t sentNum)
00108 {
00109 const std::size_t len = s.size();
00110
00111 NGram ngram;
00112
00113 for (std::size_t i = 0; i < len; ++i) {
00114
00115
00116 for (std::size_t n = 1; n <= std::min(kMaxNGramLength, len-i); ++n) {
00117 ngram.clear();
00118 for (std::size_t j = 0; j < n; ++j) {
00119 ngram.push_back(s[i+j]);
00120 }
00121 m_ngramCoordinateMap[ngram].intraSentencePositions[sentNum].push_back(i);
00122 }
00123 }
00124 }
00125
00126 bool StringCfgFilter::GeneratePattern(const std::vector<StringPiece> &symbols,
00127 Pattern &pattern) const
00128 {
00129 pattern.subpatterns.clear();
00130 pattern.minGapWidths.clear();
00131
00132 int gapWidth = 0;
00133
00134
00135
00136 if (IsNonTerminal(symbols[0])) {
00137 ++gapWidth;
00138 } else {
00139 pattern.minGapWidths.push_back(0);
00140
00141 Vocabulary::IdType vocabId =
00142 m_testVocab.Lookup(symbols[0], StringPieceCompatibleHash(),
00143 StringPieceCompatibleEquals());
00144 if (vocabId == Vocabulary::NullId()) {
00145 return false;
00146 }
00147 pattern.subpatterns.push_back(NGram(1, vocabId));
00148 }
00149
00150
00151 for (std::size_t i = 1; i < symbols.size()-1; ++i) {
00152
00153 if (IsNonTerminal(symbols[i])) {
00154 ++gapWidth;
00155 continue;
00156 }
00157
00158 if (gapWidth > 0) {
00159 pattern.minGapWidths.push_back(gapWidth);
00160 gapWidth = 0;
00161 pattern.subpatterns.resize(pattern.subpatterns.size()+1);
00162
00163 } else if (pattern.subpatterns.back().size() == kMaxNGramLength) {
00164 pattern.minGapWidths.push_back(0);
00165 pattern.subpatterns.resize(pattern.subpatterns.size()+1);
00166 }
00167
00168 Vocabulary::IdType vocabId =
00169 m_testVocab.Lookup(symbols[i], StringPieceCompatibleHash(),
00170 StringPieceCompatibleEquals());
00171 if (vocabId == Vocabulary::NullId()) {
00172 return false;
00173 }
00174 pattern.subpatterns.back().push_back(vocabId);
00175 }
00176
00177
00178 pattern.minGapWidths.push_back(gapWidth);
00179 return true;
00180 }
00181
00182 bool StringCfgFilter::IsNonTerminal(const StringPiece &symbol) const
00183 {
00184 return symbol.size() >= 3 && symbol[0] == '[' &&
00185 symbol[symbol.size()-1] == ']';
00186 }
00187
00188 bool StringCfgFilter::MatchPattern(const Pattern &pattern) const
00189 {
00190
00191
00192
00193 if (pattern.subpatterns.empty()) {
00194 assert(pattern.minGapWidths.size() == 1);
00195 return pattern.minGapWidths[0] <= m_maxSentenceLength;
00196 }
00197
00198
00199
00200 std::vector<const CoordinateTable *> tables;
00201 for (std::vector<NGram>::const_iterator p = pattern.subpatterns.begin();
00202 p != pattern.subpatterns.end(); ++p) {
00203 NGramCoordinateMap::const_iterator q = m_ngramCoordinateMap.find(*p);
00204
00205
00206 if (q == m_ngramCoordinateMap.end()) {
00207 return false;
00208 }
00209 tables.push_back(&(q->second));
00210 }
00211
00212
00213
00214 std::vector<int> intersection = tables[0]->sentences;
00215 std::vector<int> tmp(intersection.size());
00216 for (std::size_t i = 1; i < tables.size(); ++i) {
00217 std::vector<int>::iterator p = std::set_intersection(
00218 intersection.begin(), intersection.end(), tables[i]->sentences.begin(),
00219 tables[i]->sentences.end(), tmp.begin());
00220 tmp.resize(p-tmp.begin());
00221 if (tmp.empty()) {
00222 return false;
00223 }
00224 intersection.swap(tmp);
00225 }
00226
00227
00228
00229
00230
00231
00232 for (std::vector<int>::const_iterator p = intersection.begin();
00233 p != intersection.end(); ++p) {
00234 if (MatchPattern(pattern, tables, *p)) {
00235 return true;
00236 }
00237 }
00238 return false;
00239 }
00240
00241 bool StringCfgFilter::MatchPattern(
00242 const Pattern &pattern,
00243 std::vector<const CoordinateTable *> &tables,
00244 int sentenceId) const
00245 {
00246 const int sentenceLength = m_sentenceLengths[sentenceId];
00247
00248
00249
00250
00251
00252 std::vector<Range> rangeSet;
00253 std::vector<Range> nextRangeSet;
00254
00255
00256 int minStart = pattern.minGapWidths[0];
00257 int maxStart = sentenceLength - MinWidth(pattern, 0);
00258 rangeSet.push_back(Range(minStart, maxStart));
00259
00260
00261 for (int i = 0; i < pattern.subpatterns.size(); ++i) {
00262
00263 boost::unordered_map<int, PositionSeq>::const_iterator r =
00264 tables[i]->intraSentencePositions.find(sentenceId);
00265 assert(r != tables[i]->intraSentencePositions.end());
00266 const PositionSeq &col = r->second;
00267 for (PositionSeq::const_iterator p = col.begin(); p != col.end(); ++p) {
00268 bool inRange = false;
00269 for (std::vector<Range>::const_iterator q = rangeSet.begin();
00270 q != rangeSet.end(); ++q) {
00271
00272 if (*p >= q->first && *p <= q->second) {
00273 inRange = true;
00274 break;
00275 }
00276 }
00277 if (!inRange) {
00278 continue;
00279 }
00280
00281 if (i+1 == pattern.subpatterns.size()) {
00282 return true;
00283 }
00284 nextRangeSet.push_back(CalcNextRange(pattern, i, *p, sentenceLength));
00285 }
00286 if (nextRangeSet.empty()) {
00287 return false;
00288 }
00289 rangeSet.swap(nextRangeSet);
00290 nextRangeSet.clear();
00291 }
00292 return true;
00293 }
00294
00295 StringCfgFilter::Range StringCfgFilter::CalcNextRange(
00296 const Pattern &pattern, int i, int x, int sentenceLength) const
00297 {
00298 assert(i+1 < pattern.subpatterns.size());
00299 Range range;
00300 if (pattern.minGapWidths[i+1] == 0) {
00301
00302 range.first = range.second = x + pattern.subpatterns[i].size();
00303 } else {
00304 range.first = x + pattern.subpatterns[i].size() + pattern.minGapWidths[i+1];
00305
00306 range.second = sentenceLength - MinWidth(pattern, i+1);
00307 }
00308 return range;
00309 }
00310
00311 int StringCfgFilter::MinWidth(const Pattern &pattern, int i) const
00312 {
00313 int minWidth = 0;
00314 for (; i < pattern.subpatterns.size(); ++i) {
00315 minWidth += pattern.subpatterns[i].size();
00316 minWidth += pattern.minGapWidths[i+1];
00317 }
00318 return minWidth;
00319 }
00320
00321 }
00322 }
00323 }