#include "lm/common/model_buffer.hh"
#include "util/exception.hh"
#include "util/file_stream.hh"
#include "util/file.hh"
#include "util/file_piece.hh"
#include "util/stream/io.hh"
#include "util/stream/multi_stream.hh"

#include <boost/lexical_cast.hpp>

#include <cassert>

namespace lm {

namespace {
// First line of the <file_base>.kenlm_intermediate metadata file; used to
// recognize buffers written by Sink() when they are reopened.
const char kMetadataHeader[] = "KenLM intermediate binary file";
} // namespace

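/* On-disk layout of a kept buffer, as written by Sink() below and read back
 * by the second constructor (the counts shown are hypothetical):
 *
 *   <file_base>.kenlm_intermediate    text metadata, e.g.
 *     KenLM intermediate binary file
 *     Counts 135 1000 10000
 *     Payload pb
 *   <file_base>.vocab                 vocabulary
 *   <file_base>.1 ... <file_base>.N   one binary n-gram file per order
 */
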
// Constructor for writing a new buffer: Sink() will drain chains into files
// named off of file_base.  If keep_buffer is false, unnamed temporary files
// are used and the buffer cannot be reopened later.
ModelBuffer::ModelBuffer(StringPiece file_base, bool keep_buffer, bool output_q)
  : file_base_(file_base.data(), file_base.size()), keep_buffer_(keep_buffer), output_q_(output_q),
    vocab_file_(keep_buffer ? util::CreateOrThrow((file_base_ + ".vocab").c_str()) : util::MakeTemp(file_base_)) {}

// Constructor for reading an existing buffer previously written with
// keep_buffer = true: parses the metadata file, then opens the vocabulary
// and the per-order n-gram files.
ModelBuffer::ModelBuffer(StringPiece file_base)
  : file_base_(file_base.data(), file_base.size()), keep_buffer_(false) {
  const std::string full_name = file_base_ + ".kenlm_intermediate";
  util::FilePiece in(full_name.c_str());
  StringPiece token = in.ReadLine();
  UTIL_THROW_IF2(token != kMetadataHeader, "File " << full_name << " begins with \"" << token << "\" not " << kMetadataHeader);

  token = in.ReadDelimited();
  UTIL_THROW_IF2(token != "Counts", "Expected Counts, got \"" << token << "\" in " << full_name);
  // Space-separated n-gram counts, one per order, terminated by a newline.
  char got;
  while ((got = in.get()) == ' ') {
    counts_.push_back(in.ReadULong());
  }
  UTIL_THROW_IF2(got != '\n', "Expected newline at end of counts in " << full_name);

  token = in.ReadDelimited();
  UTIL_THROW_IF2(token != "Payload", "Expected Payload, got \"" << token << "\" in " << full_name);
  token = in.ReadDelimited();
  if (token == "q") {
    output_q_ = true;
  } else if (token == "pb") {
    output_q_ = false;
  } else {
    UTIL_THROW(util::Exception, "Unknown payload " << token << " in " << full_name);
  }

  vocab_file_.reset(util::OpenReadOrThrow((file_base_ + ".vocab").c_str()));

  // One file per order, numbered from 1.
  files_.Init(counts_.size());
  for (std::size_t i = 0; i < counts_.size(); ++i) {
    files_.push_back(util::OpenReadOrThrow((file_base_ + '.' + boost::lexical_cast<std::string>(i + 1)).c_str()));
  }
}

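/* A sketch of the intended round trip (chain setup elided; "base" and the
 * chain variables are illustrative names):
 *
 *   ModelBuffer writing("base", true, false);  // keep_buffer = true, "pb" payload
 *   writing.Sink(chains, counts);              // drain chains to base.1 .. base.N
 *   // ... wait for the chains to finish ...
 *   ModelBuffer reading("base");               // reopen via base.kenlm_intermediate
 *   reading.Source(read_chains);               // stream the stored n-grams back out
 */
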
// Attach the chains as writers: each chain's stream is written to its own
// file.  When the buffer is kept, also record the metadata needed to reopen
// it with the reading constructor.
void ModelBuffer::Sink(util::stream::Chains &chains, const std::vector<uint64_t> &counts) {
  counts_ = counts;

  files_.Init(chains.size());
  for (std::size_t i = 0; i < chains.size(); ++i) {
    if (keep_buffer_) {
      files_.push_back(util::CreateOrThrow((file_base_ + '.' + boost::lexical_cast<std::string>(i + 1)).c_str()));
    } else {
      files_.push_back(util::MakeTemp(file_base_));
    }
    chains[i] >> util::stream::Write(files_.back().get());
  }
  if (keep_buffer_) {
    util::scoped_fd metadata(util::CreateOrThrow((file_base_ + ".kenlm_intermediate").c_str()));
    util::FileStream meta(metadata.get(), 200);
    meta << kMetadataHeader << "\nCounts";
    for (std::vector<uint64_t>::const_iterator i = counts_.begin(); i != counts_.end(); ++i) {
      meta << ' ' << *i;
    }
    meta << "\nPayload " << (output_q_ ? "q" : "pb") << '\n';
  }
}

// Attach the chains as readers over the stored files, one chain per order.
void ModelBuffer::Source(util::stream::Chains &chains) {
  assert(chains.size() <= files_.size());
  for (std::size_t i = 0; i < chains.size(); ++i) {
    chains[i] >> util::stream::PRead(files_[i].get());
  }
}

// Attach a single chain as a reader over the file for one order (0-indexed).
void ModelBuffer::Source(std::size_t order_minus_1, util::stream::Chain &chain) {
  chain >> util::stream::PRead(files_[order_minus_1].get());
}

} // namespace lm