#include "lm/builder/corpus_count.hh"

#include "lm/builder/payload.hh"
#include "lm/common/ngram.hh"
#include "lm/lm_exception.hh"
#include "lm/vocab.hh"
#include "lm/word_index.hh"
#include "util/file_stream.hh"
#include "util/file.hh"
#include "util/file_piece.hh"
#include "util/murmur_hash.hh"
#include "util/probing_hash_table.hh"
#include "util/scoped.hh"
#include "util/stream/chain.hh"
#include "util/stream/timer.hh"
#include "util/tokenize_piece.hh"

#include <boost/scoped_array.hpp>

#include <algorithm>
#include <functional>
#include <iostream>
#include <limits>
#include <vector>

#include <assert.h>
#include <stdint.h>
#include <string.h>

namespace lm {
namespace builder {
namespace {

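// Hashes the words of an n-gram as one contiguous run of bytes.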
class DedupeHash {
  public:
    explicit DedupeHash(std::size_t order) : size_(order * sizeof(WordIndex)) {}

    std::size_t operator()(const WordIndex *start) const {
      return util::MurmurHashNative(start, size_);
    }

  private:
    const std::size_t size_;
};

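// Compares two n-grams bytewise for equality.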
class DedupeEquals {
  public:
    explicit DedupeEquals(std::size_t order) : size_(order * sizeof(WordIndex)) {}

    bool operator()(const WordIndex *first, const WordIndex *second) const {
      return !memcmp(first, second, size_);
    }

  private:
    const std::size_t size_;
};

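// An entry in the dedupe table is just a pointer to an n-gram stored inside
// the current output block.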
struct DedupeEntry {
  typedef WordIndex *Key;
  Key GetKey() const { return key; }
  void SetKey(WordIndex *to) { key = to; }
  Key key;
  static DedupeEntry Construct(WordIndex *at) {
    DedupeEntry ret;
    ret.key = at;
    return ret;
  }
};

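// The probing hash table gets 1.5 buckets per possible entry, so it never
// exceeds a 2/3 load factor.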
const float kProbingMultiplier = 1.5;

typedef util::ProbingHashTable<DedupeEntry, DedupeHash, DedupeEquals> Dedupe;

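// Writer emits counted n-grams into the chain's blocks.  Within a block,
// n-grams are deduplicated through the probing hash table: a repeated n-gram
// just has the count of its existing copy incremented in place.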
class Writer {
  public:
    Writer(std::size_t order, const util::stream::ChainPosition &position, void *dedupe_mem, std::size_t dedupe_mem_size)
      : block_(position), gram_(block_->Get(), order),
        dedupe_invalid_(order, std::numeric_limits<WordIndex>::max()),
        dedupe_(dedupe_mem, dedupe_mem_size, &dedupe_invalid_[0], DedupeHash(order), DedupeEquals(order)),
        buffer_(new WordIndex[order - 1]),
        block_size_(position.GetChain().BlockSize()) {
      dedupe_.Clear();
      assert(Dedupe::Size(position.GetChain().BlockSize() / position.GetChain().EntrySize(), kProbingMultiplier) == dedupe_mem_size);
      if (order == 1) {
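        // Unigram mode: guarantee <unk> and <s> appear in the output even
        // though they are never counted as tokens.  They get count 0 here.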
        AddUnigramWord(kUNK);
        AddUnigramWord(kBOS);
      }
    }

    ~Writer() {
      block_->SetValidSize(reinterpret_cast<const uint8_t*>(gram_.begin()) - static_cast<const uint8_t*>(block_->Get()));
      (++block_).Poison();
    }
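    // Begin a sentence by filling the context positions with <s>.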
    void StartSentence() {
      for (WordIndex *i = gram_.begin(); i != gram_.end() - 1; ++i) {
        *i = kBOS;
      }
    }

    void Append(WordIndex word) {
      *(gram_.end() - 1) = word;
      Dedupe::MutableIterator at;
      bool found = dedupe_.FindOrInsert(DedupeEntry::Construct(gram_.begin()), at);
      if (found) {
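        // This n-gram is already in the block: increment its count.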
        NGram<BuildingPayload> already(at->key, gram_.Order());
        ++(already.Value().count);
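        // Shift left by one word; the last order - 1 words become the
        // context of the next n-gram.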
        memmove(gram_.begin(), gram_.begin() + 1, sizeof(WordIndex) * (gram_.Order() - 1));
        return;
      }
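      // Novel n-gram: it stays in the block with count 1.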
      gram_.Value().count = 1;
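      // If the block still has room, start the next n-gram immediately after
      // this one, seeded with this n-gram's last order - 1 words.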
      if (reinterpret_cast<uint8_t*>(gram_.begin()) + gram_.TotalSize() != static_cast<uint8_t*>(block_->Get()) + block_size_) {
        NGram<BuildingPayload> last(gram_);
        gram_.NextInMemory();
        std::copy(last.begin() + 1, last.end(), gram_.begin());
        return;
      }
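      // The block is full: stash the context in buffer_, flush the block,
      // and restart deduplication in the next block.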
      std::copy(gram_.begin() + 1, gram_.end(), buffer_.get());
      dedupe_.Clear();
      block_->SetValidSize(block_size_);
      gram_.ReBase((++block_)->Get());
      std::copy(buffer_.get(), buffer_.get() + gram_.Order() - 1, gram_.begin());
    }

  private:
    void AddUnigramWord(WordIndex index) {
      *gram_.begin() = index;
      gram_.Value().count = 0;
      gram_.NextInMemory();
      if (gram_.Base() == static_cast<uint8_t*>(block_->Get()) + block_size_) {
        block_->SetValidSize(block_size_);
        gram_.ReBase((++block_)->Get());
      }
    }

    util::stream::Link block_;

    NGram<BuildingPayload> gram_;

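    // Backing memory for the invalid (empty bucket) key of dedupe_.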
    std::vector<WordIndex> dedupe_invalid_;

    Dedupe dedupe_;

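    // Holds the context words while wrapping around to a fresh block.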
    boost::scoped_array<WordIndex> buffer_;

    const std::size_t block_size_;
};

} // namespace

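// Ratio of dedupe hash table memory to n-gram record memory, used by callers
// when budgeting block sizes.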
float CorpusCount::DedupeMultiplier(std::size_t order) {
  return kProbingMultiplier * static_cast<float>(sizeof(DedupeEntry)) / static_cast<float>(NGram<BuildingPayload>::TotalSize(order));
}

std::size_t CorpusCount::VocabUsage(std::size_t vocab_estimate) {
  return ngram::GrowableVocab<ngram::WriteUniqueWords>::MemUsage(vocab_estimate);
}

CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::vector<bool> &prune_words, const std::string &prune_vocab_filename, std::size_t entries_per_block, WarningAction disallowed_symbol)
  : from_(from), vocab_write_(vocab_write), token_count_(token_count), type_count_(type_count),
    prune_words_(prune_words), prune_vocab_filename_(prune_vocab_filename),
    dedupe_mem_size_(Dedupe::Size(entries_per_block, kProbingMultiplier)),
    dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)),
    disallowed_symbol_action_(disallowed_symbol) {
}

namespace {
void ComplainDisallowed(StringPiece word, WarningAction &action) {
  switch (action) {
    case SILENT:
      return;
    case COMPLAIN:
      std::cerr << "Warning: " << word << " appears in the input. All instances of <s>, </s>, and <unk> will be interpreted as whitespace." << std::endl;
      action = SILENT;
      return;
    case THROW_UP:
      UTIL_THROW(FormatLoadException, "Special word " << word << " is not allowed in the corpus. I plan to support models containing <unk> in the future. Pass --skip_symbols to convert these symbols to whitespace.");
  }
}
} // namespace

void CorpusCount::Run(const util::stream::ChainPosition &position) {
  ngram::GrowableVocab<ngram::WriteUniqueWords> vocab(type_count_, vocab_write_);
  token_count_ = 0;
  type_count_ = 0;
  const WordIndex end_sentence = vocab.FindOrInsert("</s>");
  Writer writer(NGram<BuildingPayload>::OrderFromSize(position.GetChain().EntrySize()), position, dedupe_mem_.get(), dedupe_mem_size_);
  uint64_t count = 0;
  bool delimiters[256];
  util::BoolCharacter::Build("\0\t\n\r ", delimiters);
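  // Each input line is one sentence; every sentence implicitly ends in </s>.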
  try {
    while (true) {
      StringPiece line(from_.ReadLine());
      writer.StartSentence();
      for (util::TokenIter<util::BoolCharacter, true> w(line, delimiters); w; ++w) {
        WordIndex word = vocab.FindOrInsert(*w);
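        // Indices 0, 1, 2 are <unk>, <s>, </s>: special words appearing
        // literally in the text are skipped, not counted.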
        if (word <= 2) {
          ComplainDisallowed(*w, disallowed_symbol_action_);
          continue;
        }
        writer.Append(word);
        ++count;
      }
      writer.Append(end_sentence);
    }
  } catch (const util::EndOfFileException &e) {}
  token_count_ = count;
  type_count_ = vocab.Size();

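  // Build the pruning list: mark every word as prunable, then clear the flag
  // for words listed in the prune vocabulary file.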
  if (!prune_vocab_filename_.empty()) {
    try {
      util::FilePiece prune_vocab_file(prune_vocab_filename_.c_str());

      prune_words_.resize(vocab.Size(), true);
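      // Words listed in the file are kept; everything else stays marked.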
      try {
        while (true) {
          StringPiece word(prune_vocab_file.ReadDelimited(delimiters));
          prune_words_[vocab.Index(word)] = false;
        }
      } catch (const util::EndOfFileException &e) {}

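      // Never prune the special words <unk>, <s>, and </s>.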
      prune_words_[kUNK] = false;
      prune_words_[kBOS] = false;
      prune_words_[kEOS] = false;
    } catch (const util::Exception &e) {
      std::cerr << e.what() << std::endl;
      abort();
    }
  }
}

} // namespace builder
} // namespace lm