00001 #include "lm/filter/phrase.hh"
00002
00003 #include "lm/filter/format.hh"
00004
00005 #include <algorithm>
00006 #include <functional>
00007 #include <iostream>
00008 #include <queue>
00009 #include <string>
00010 #include <vector>
00011
00012 #include <cctype>
00013
00014 namespace lm {
00015 namespace phrase {
00016
00017 unsigned int ReadMultiple(std::istream &in, Substrings &out) {
00018 bool sentence_content = false;
00019 unsigned int sentence_id = 0;
00020 std::vector<Hash> phrase;
00021 std::string word;
00022 while (in) {
00023 char c;
00024
00025 while (!isspace(c = in.get()) && in) word += c;
00026
00027 if (!in) c = '\n';
00028
00029 if (!word.empty()) {
00030 phrase.push_back(util::MurmurHashNative(word.data(), word.size()));
00031 word.clear();
00032 }
00033 if (c == ' ') continue;
00034
00035 if (!phrase.empty()) {
00036 sentence_content = true;
00037 out.AddPhrase(sentence_id, phrase.begin(), phrase.end());
00038 phrase.clear();
00039 }
00040 if (c == '\t' || c == '\v') continue;
00041
00042 if (sentence_content) {
00043 ++sentence_id;
00044 sentence_content = false;
00045 }
00046 }
00047 if (!in.eof()) in.exceptions(std::istream::failbit | std::istream::badbit);
00048 return sentence_id + sentence_content;
00049 }
00050
namespace {
// Sentence identifier.  NOTE(review): presumably the same ids ReadMultiple
// hands to Substrings::AddPhrase -- confirm against phrase.hh.
typedef unsigned int Sentence;
// A list of sentence ids.  Arc walks these with std::lower_bound, so they are
// assumed to be sorted ascending.
typedef std::vector<Sentence> Sentences;
}
00055
00056 namespace detail {
00057
// Text of the end-of-sentence token.  NOTE(review): not referenced anywhere in
// this file; presumably declared in phrase.hh for use by callers -- confirm
// before changing its linkage or removing it.
const StringPiece kEndSentence("</s>");
00059
00060 class Arc {
00061 public:
00062 Arc() {}
00063
00064
00065 void SetPhrase(detail::Vertex &from, detail::Vertex &to, const Sentences &intersect) {
00066 Set(to, intersect);
00067 from_ = &from;
00068 }
00069
00070
00071
00072
00073
00074 void SetRight(detail::Vertex &to, const Sentences &complete) {
00075 Set(to, complete);
00076 from_ = NULL;
00077 }
00078
00079 Sentence Current() const {
00080 return *current_;
00081 }
00082
00083 bool Empty() const {
00084 return current_ == last_;
00085 }
00086
00087
00088
00089
00090
00091
00092
00093
00094
00095
00096
00097 void LowerBound(const Sentence to);
00098
00099 private:
00100 void Set(detail::Vertex &to, const Sentences &sentences);
00101
00102 const Sentence *current_;
00103 const Sentence *last_;
00104 detail::Vertex *from_;
00105 };
00106
00107 struct ArcGreater : public std::binary_function<const Arc *, const Arc *, bool> {
00108 bool operator()(const Arc *first, const Arc *second) const {
00109 return first->Current() > second->Current();
00110 }
00111 };
00112
00113 class Vertex {
00114 public:
00115 Vertex() : current_(0) {}
00116
00117 Sentence Current() const {
00118 return current_;
00119 }
00120
00121 bool Empty() const {
00122 return incoming_.empty();
00123 }
00124
00125 void LowerBound(const Sentence to);
00126
00127 private:
00128 friend class Arc;
00129
00130 void AddIncoming(Arc *arc) {
00131 if (!arc->Empty()) incoming_.push(arc);
00132 }
00133
00134 unsigned int current_;
00135 std::priority_queue<Arc*, std::vector<Arc*>, ArcGreater> incoming_;
00136 };
00137
// Moves this arc's cursor to the first sentence id >= `to` that is still
// consistent with its source vertex (if it has one).  Assumes the arc's id list
// is sorted ascending, as std::lower_bound requires.
void Arc::LowerBound(const Sentence to) {
  // First advance within this arc's own id list.
  current_ = std::lower_bound(current_, last_, to);

  // Arcs without a source vertex (SetRight), exhausted arcs, and arcs that
  // already skipped past `to` need no further work.
  if (!from_ || Empty() || (Current() > to)) return;
  assert(Current() == to);
  // This arc matches `to` exactly; the path through it also needs the source
  // vertex to reach `to`.
  from_->LowerBound(to);
  if (from_->Empty()) {
    // The source vertex has no ids left, so this arc can never match again.
    current_ = last_;
    return;
  }
  assert(from_->Current() >= to);
  if (from_->Current() > to) {
    // The source vertex skipped past `to`; jump this arc to the first id the
    // source vertex could still support.
    current_ = std::lower_bound(current_ + 1, last_, from_->Current());
  }
}
00154
00155 void Arc::Set(Vertex &to, const Sentences &sentences) {
00156 current_ = &*sentences.begin();
00157 last_ = &*sentences.end();
00158 to.AddIncoming(this);
00159 }
00160
// Finds the smallest sentence id >= `to` offered by any incoming arc and
// caches it in current_.  If every incoming arc runs out, the vertex becomes
// Empty() and current_ is left unchanged; callers in this file check Empty()
// before reading Current().
void Vertex::LowerBound(const Sentence to) {
  if (Empty()) return;

  // ArcGreater makes the heap top the arc with the smallest current id.
  while (true) {
    Arc *top = incoming_.top();
    if (top->Current() > to) {
      // Every remaining arc is already beyond `to`; the smallest is the answer.
      current_ = top->Current();
      return;
    }

    // top->Current() <= to.  Remove the arc before advancing it so the heap
    // order stays valid, then re-insert it only if it still has ids.
    incoming_.pop();
    top->LowerBound(to);
    if (!top->Empty()) {
      incoming_.push(top);
      if (top->Current() == to) {
        // Exact hit: `to` is reachable through this arc.
        current_ = to;
        return;
      }
    } else if (Empty()) {
      // The last incoming arc was exhausted.
      return;
    }
  }
}
00185
00186 }
00187
00188 namespace {
00189
00190 void BuildGraph(const Substrings &phrase, const std::vector<Hash> &hashes, detail::Vertex *const vertices, detail::Arc *free_arc) {
00191 using detail::Vertex;
00192 using detail::Arc;
00193 assert(!hashes.empty());
00194
00195 const Hash *const first_word = &*hashes.begin();
00196 const Hash *const last_word = &*hashes.end() - 1;
00197
00198 Hash hash = 0;
00199 const Sentences *found;
00200
00201 {
00202 Vertex *vertex = vertices;
00203 for (const Hash *word = first_word; ; ++word, ++vertex) {
00204 hash = util::MurmurHashNative(&hash, sizeof(uint64_t), *word);
00205
00206 if (word == last_word) {
00207 if (phrase.FindSubstring(hash, found))
00208 (free_arc++)->SetRight(*vertex, *found);
00209 break;
00210 }
00211 if (!phrase.FindRight(hash, found)) break;
00212 (free_arc++)->SetRight(*vertex, *found);
00213 }
00214 }
00215
00216
00217 Vertex *vertex_from = vertices;
00218 for (const Hash *word_from = first_word + 1; word_from != &*hashes.end(); ++word_from, ++vertex_from) {
00219 hash = 0;
00220 Vertex *vertex_to = vertex_from + 1;
00221 for (const Hash *word_to = word_from; ; ++word_to, ++vertex_to) {
00222
00223 hash = util::MurmurHashNative(&hash, sizeof(uint64_t), *word_to);
00224
00225 if (word_to == last_word) {
00226 if (phrase.FindLeft(hash, found))
00227 (free_arc++)->SetPhrase(*vertex_from, *vertex_to, *found);
00228 break;
00229 }
00230 if (!phrase.FindPhrase(hash, found)) break;
00231 (free_arc++)->SetPhrase(*vertex_from, *vertex_to, *found);
00232 }
00233 }
00234 }
00235
00236 }
00237
00238 namespace detail {
00239
00240
00241 ConditionCommon::ConditionCommon(const Substrings &substrings) : substrings_(substrings) {}
00242
00243
00244 ConditionCommon::ConditionCommon(const ConditionCommon &from) : substrings_(from.substrings_) {}
00245
00246 ConditionCommon::~ConditionCommon() {}
00247
00248 detail::Vertex &ConditionCommon::MakeGraph() {
00249 assert(!hashes_.empty());
00250 vertices_.clear();
00251 vertices_.resize(hashes_.size());
00252 arcs_.clear();
00253
00254 arcs_.resize(((hashes_.size() + 1) * hashes_.size()) / 2);
00255 BuildGraph(substrings_, hashes_, &*vertices_.begin(), &*arcs_.begin());
00256 return vertices_[hashes_.size() - 1];
00257 }
00258
00259 }
00260
00261 bool Union::Evaluate() {
00262 detail::Vertex &last_vertex = MakeGraph();
00263 unsigned int lower = 0;
00264 while (true) {
00265 last_vertex.LowerBound(lower);
00266 if (last_vertex.Empty()) return false;
00267 if (last_vertex.Current() == lower) return true;
00268 lower = last_vertex.Current();
00269 }
00270 }
00271
00272 template <class Output> void Multiple::Evaluate(const StringPiece &line, Output &output) {
00273 detail::Vertex &last_vertex = MakeGraph();
00274 unsigned int lower = 0;
00275 while (true) {
00276 last_vertex.LowerBound(lower);
00277 if (last_vertex.Empty()) return;
00278 if (last_vertex.Current() == lower) {
00279 output.SingleAddNGram(lower, line);
00280 ++lower;
00281 } else {
00282 lower = last_vertex.Current();
00283 }
00284 }
00285 }
00286
// Multiple::Evaluate is defined in this file rather than in the header, so each
// output type it is used with must be explicitly instantiated here for other
// translation units to link against it.
template void Multiple::Evaluate<CountFormat::Multiple>(const StringPiece &line, CountFormat::Multiple &output);
template void Multiple::Evaluate<ARPAFormat::Multiple>(const StringPiece &line, ARPAFormat::Multiple &output);
template void Multiple::Evaluate<MultipleOutputBuffer>(const StringPiece &line, MultipleOutputBuffer &output);
00290
00291 }
00292 }