00001 #include "lm/trie_sort.hh"
00002
00003 #include "lm/config.hh"
00004 #include "lm/lm_exception.hh"
00005 #include "lm/read_arpa.hh"
00006 #include "lm/vocab.hh"
00007 #include "lm/weights.hh"
00008 #include "lm/word_index.hh"
00009 #include "util/file_piece.hh"
00010 #include "util/mmap.hh"
00011 #include "util/proxy_iterator.hh"
00012 #include "util/sized_iterator.hh"
00013
00014 #include <algorithm>
00015 #include <cstring>
00016 #include <cstdio>
00017 #include <cstdlib>
00018 #include <deque>
00019 #include <iterator>
00020 #include <limits>
00021 #include <vector>
00022
00023 namespace lm {
00024 namespace ngram {
00025 namespace trie {
00026 namespace {
00027
// Iterator type used further down to sort whole fixed-size n-gram records in
// place (see the std::sort / std::stable_sort call in ConvertToSorted).
typedef util::SizedIterator NGramIter;
00029
00030
00031 class PartialViewProxy {
00032 public:
00033 PartialViewProxy() : attention_size_(0), inner_() {}
00034
00035 PartialViewProxy(void *ptr, std::size_t block_size, std::size_t attention_size) : attention_size_(attention_size), inner_(ptr, block_size) {}
00036
00037 operator std::string() const {
00038 return std::string(reinterpret_cast<const char*>(inner_.Data()), attention_size_);
00039 }
00040
00041 PartialViewProxy &operator=(const PartialViewProxy &from) {
00042 memcpy(inner_.Data(), from.inner_.Data(), attention_size_);
00043 return *this;
00044 }
00045
00046 PartialViewProxy &operator=(const std::string &from) {
00047 memcpy(inner_.Data(), from.data(), attention_size_);
00048 return *this;
00049 }
00050
00051 const void *Data() const { return inner_.Data(); }
00052 void *Data() { return inner_.Data(); }
00053
00054 friend void swap(PartialViewProxy first, PartialViewProxy second) {
00055 std::swap_ranges(reinterpret_cast<char*>(first.Data()), reinterpret_cast<char*>(first.Data()) + first.attention_size_, reinterpret_cast<char*>(second.Data()));
00056 }
00057
00058 private:
00059 friend class util::ProxyIterator<PartialViewProxy>;
00060
00061 typedef std::string value_type;
00062
00063 const std::size_t attention_size_;
00064
00065 typedef util::SizedInnerIterator InnerIterator;
00066 InnerIterator &Inner() { return inner_; }
00067 const InnerIterator &Inner() const { return inner_; }
00068 InnerIterator inner_;
00069 };
00070
00071 typedef util::ProxyIterator<PartialViewProxy> PartialIter;
00072
00073 FILE *DiskFlush(const void *mem_begin, const void *mem_end, const std::string &temp_prefix) {
00074 util::scoped_fd file(util::MakeTemp(temp_prefix));
00075 util::WriteOrThrow(file.get(), mem_begin, (uint8_t*)mem_end - (uint8_t*)mem_begin);
00076 return util::FDOpenOrThrow(file);
00077 }
00078
00079 FILE *WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &temp_prefix, std::size_t entry_size, unsigned char order) {
00080 const size_t context_size = sizeof(WordIndex) * (order - 1);
00081
00082 PartialIter context_begin(PartialViewProxy(begin + sizeof(WordIndex), entry_size, context_size));
00083 PartialIter context_end(PartialViewProxy(end + sizeof(WordIndex), entry_size, context_size));
00084
00085 #if defined(_WIN32) || defined(_WIN64)
00086 std::stable_sort
00087 #else
00088 std::sort
00089 #endif
00090 (context_begin, context_end, util::SizedCompare<EntryCompare, PartialViewProxy>(EntryCompare(order - 1)));
00091
00092 util::scoped_FILE out(util::FMakeTemp(temp_prefix));
00093
00094
00095 if (context_begin == context_end) return out.release();
00096 PartialIter i(context_begin);
00097 util::WriteOrThrow(out.get(), i->Data(), context_size);
00098 const void *previous = i->Data();
00099 ++i;
00100 for (; i != context_end; ++i) {
00101 if (memcmp(previous, i->Data(), context_size)) {
00102 util::WriteOrThrow(out.get(), i->Data(), context_size);
00103 previous = i->Data();
00104 }
00105 }
00106 return out.release();
00107 }
00108
00109 struct ThrowCombine {
00110 void operator()(std::size_t entry_size, unsigned char order, const void *first, const void *second, FILE * ) const {
00111 const WordIndex *base = reinterpret_cast<const WordIndex*>(first);
00112 FormatLoadException e;
00113 e << "Duplicate n-gram detected with vocab ids";
00114 for (const WordIndex *i = base; i != base + order; ++i) {
00115 e << ' ' << *i;
00116 }
00117 throw e;
00118 }
00119 };
00120
00121
00122 struct FirstCombine {
00123 void operator()(std::size_t entry_size, unsigned char , const void *first, const void * , FILE *out) const {
00124 util::WriteOrThrow(out, first, entry_size);
00125 }
00126 };
00127
00128 template <class Combine> FILE *MergeSortedFiles(FILE *first_file, FILE *second_file, const std::string &temp_prefix, std::size_t weights_size, unsigned char order, const Combine &combine) {
00129 std::size_t entry_size = sizeof(WordIndex) * order + weights_size;
00130 RecordReader first, second;
00131 first.Init(first_file, entry_size);
00132 second.Init(second_file, entry_size);
00133 util::scoped_FILE out_file(util::FMakeTemp(temp_prefix));
00134 EntryCompare less(order);
00135 while (first && second) {
00136 if (less(first.Data(), second.Data())) {
00137 util::WriteOrThrow(out_file.get(), first.Data(), entry_size);
00138 ++first;
00139 } else if (less(second.Data(), first.Data())) {
00140 util::WriteOrThrow(out_file.get(), second.Data(), entry_size);
00141 ++second;
00142 } else {
00143 combine(entry_size, order, first.Data(), second.Data(), out_file.get());
00144 ++first; ++second;
00145 }
00146 }
00147 for (RecordReader &remains = (first ? first : second); remains; ++remains) {
00148 util::WriteOrThrow(out_file.get(), remains.Data(), entry_size);
00149 }
00150 return out_file.release();
00151 }
00152
00153 }
00154
00155 void RecordReader::Init(FILE *file, std::size_t entry_size) {
00156 entry_size_ = entry_size;
00157 data_.reset(malloc(entry_size));
00158 UTIL_THROW_IF(!data_.get(), util::ErrnoException, "Failed to malloc read buffer");
00159 file_ = file;
00160 if (file) {
00161 rewind(file);
00162 remains_ = true;
00163 ++*this;
00164 } else {
00165 remains_ = false;
00166 }
00167 }
00168
00169 void RecordReader::Overwrite(const void *start, std::size_t amount) {
00170 long internal = (uint8_t*)start - (uint8_t*)data_.get();
00171 UTIL_THROW_IF(fseek(file_, internal - entry_size_, SEEK_CUR), util::ErrnoException, "Couldn't seek backwards for revision");
00172 util::WriteOrThrow(file_, start, amount);
00173 long forward = entry_size_ - internal - amount;
00174 #if !defined(_WIN32) && !defined(_WIN64)
00175 if (forward)
00176 #endif
00177 UTIL_THROW_IF(fseek(file_, forward, SEEK_CUR), util::ErrnoException, "Couldn't seek forwards past revision");
00178 }
00179
00180 void RecordReader::Rewind() {
00181 if (file_) {
00182 rewind(file_);
00183 remains_ = true;
00184 ++*this;
00185 } else {
00186 remains_ = false;
00187 }
00188 }
00189
00190 SortedFiles::SortedFiles(const Config &config, util::FilePiece &f, std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) {
00191 PositiveProbWarn warn(config.positive_log_probability);
00192 unigram_.reset(util::MakeTemp(file_prefix));
00193 {
00194
00195 size_t size_out = (counts[0] + 1) * sizeof(ProbBackoff);
00196 util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_.get(), size_out), size_out);
00197 Read1Grams(f, counts[0], vocab, reinterpret_cast<ProbBackoff*>(unigram_mmap.get()), warn);
00198 CheckSpecials(config, vocab);
00199 if (!vocab.SawUnk()) ++counts[0];
00200 }
00201
00202
00203 size_t buffer_use = 0;
00204 for (unsigned int order = 2; order < counts.size(); ++order) {
00205 buffer_use = std::max<size_t>(buffer_use, static_cast<size_t>((sizeof(WordIndex) * order + 2 * sizeof(float)) * counts[order - 1]));
00206 }
00207 buffer_use = std::max<size_t>(buffer_use, static_cast<size_t>((sizeof(WordIndex) * counts.size() + sizeof(float)) * counts.back()));
00208 buffer = std::min<size_t>(buffer, buffer_use);
00209
00210 util::scoped_malloc mem;
00211 mem.reset(malloc(buffer));
00212 if (!mem.get()) UTIL_THROW(util::ErrnoException, "malloc failed for sort buffer size " << buffer);
00213
00214 for (unsigned char order = 2; order <= counts.size(); ++order) {
00215 ConvertToSorted(f, vocab, counts, file_prefix, order, warn, mem.get(), buffer);
00216 }
00217 ReadEnd(f);
00218 }
00219
00220 namespace {
00221 class Closer {
00222 public:
00223 explicit Closer(std::deque<FILE*> &files) : files_(files) {}
00224
00225 ~Closer() {
00226 for (std::deque<FILE*>::iterator i = files_.begin(); i != files_.end(); ++i) {
00227 util::scoped_FILE deleter(*i);
00228 }
00229 }
00230
00231 void PopFront() {
00232 util::scoped_FILE deleter(files_.front());
00233 files_.pop_front();
00234 }
00235 private:
00236 std::deque<FILE*> &files_;
00237 };
00238 }
00239
00240 void SortedFiles::ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, const std::string &file_prefix, unsigned char order, PositiveProbWarn &warn, void *mem, std::size_t mem_size) {
00241 ReadNGramHeader(f, order);
00242 const size_t count = counts[order - 1];
00243
00244 const size_t words_size = sizeof(WordIndex) * order;
00245 const size_t weights_size = sizeof(float) + ((order == counts.size()) ? 0 : sizeof(float));
00246 const size_t entry_size = words_size + weights_size;
00247 const size_t batch_size = std::min(count, mem_size / entry_size);
00248 uint8_t *const begin = reinterpret_cast<uint8_t*>(mem);
00249
00250 std::deque<FILE*> files, contexts;
00251 Closer files_closer(files), contexts_closer(contexts);
00252
00253 for (std::size_t batch = 0, done = 0; done < count; ++batch) {
00254 uint8_t *out = begin;
00255 uint8_t *out_end = out + std::min(count - done, batch_size) * entry_size;
00256 if (order == counts.size()) {
00257 for (; out != out_end; out += entry_size) {
00258 std::reverse_iterator<WordIndex*> it(reinterpret_cast<WordIndex*>(out) + order);
00259 ReadNGram(f, order, vocab, it, *reinterpret_cast<Prob*>(out + words_size), warn);
00260 }
00261 } else {
00262 for (; out != out_end; out += entry_size) {
00263 std::reverse_iterator<WordIndex*> it(reinterpret_cast<WordIndex*>(out) + order);
00264 ReadNGram(f, order, vocab, it, *reinterpret_cast<ProbBackoff*>(out + words_size), warn);
00265 }
00266 }
00267
00268 util::SizedProxy proxy_begin(begin, entry_size), proxy_end(out_end, entry_size);
00269
00270 #if defined(_WIN32) || defined(_WIN64)
00271 std::stable_sort
00272 #else
00273 std::sort
00274 #endif
00275 (NGramIter(proxy_begin), NGramIter(proxy_end), util::SizedCompare<EntryCompare>(EntryCompare(order)));
00276 files.push_back(DiskFlush(begin, out_end, file_prefix));
00277 contexts.push_back(WriteContextFile(begin, out_end, file_prefix, entry_size, order));
00278
00279 done += (out_end - begin) / entry_size;
00280 }
00281
00282
00283
00284 while (files.size() > 1) {
00285 files.push_back(MergeSortedFiles(files[0], files[1], file_prefix, weights_size, order, ThrowCombine()));
00286 files_closer.PopFront();
00287 files_closer.PopFront();
00288 contexts.push_back(MergeSortedFiles(contexts[0], contexts[1], file_prefix, 0, order - 1, FirstCombine()));
00289 contexts_closer.PopFront();
00290 contexts_closer.PopFront();
00291 }
00292
00293 if (!files.empty()) {
00294
00295 full_[order - 2].reset(files.front());
00296 files.pop_front();
00297 context_[order - 2].reset(contexts.front());
00298 contexts.pop_front();
00299 }
00300 }
00301
00302 }
00303 }
00304 }