00001 #ifndef LM_QUANTIZE_H
00002 #define LM_QUANTIZE_H
00003
00004 #include "lm/blank.hh"
00005 #include "lm/config.hh"
00006 #include "lm/max_order.hh"
00007 #include "lm/model_type.hh"
00008 #include "util/bit_packing.hh"
00009
00010 #include <algorithm>
00011 #include <vector>
00012
00013 #include <stdint.h>
00014
00015 #include <iostream>
00016
00017 namespace lm {
00018 namespace ngram {
00019
// Forward declarations for the types that appear in the signatures below.
// BinaryFormat is only ever passed by reference in this header, so a
// declaration suffices.  Config's members (prob_bits, backoff_bits) are read
// by inline functions further down, so its full definition must also be
// visible -- presumably via lm/config.hh included above (TODO confirm).
struct Config;
class BinaryFormat;
00022
00023
00024 class DontQuantize {
00025 public:
00026 static const ModelType kModelTypeAdd = static_cast<ModelType>(0);
00027 static void UpdateConfigFromBinary(const BinaryFormat &, uint64_t, Config &) {}
00028 static uint64_t Size(uint8_t , const Config &) { return 0; }
00029 static uint8_t MiddleBits(const Config &) { return 63; }
00030 static uint8_t LongestBits(const Config &) { return 31; }
00031
00032 class MiddlePointer {
00033 public:
00034 MiddlePointer(const DontQuantize & , unsigned char , util::BitAddress address) : address_(address) {}
00035
00036 MiddlePointer() : address_(NULL, 0) {}
00037
00038 bool Found() const {
00039 return address_.base != NULL;
00040 }
00041
00042 float Prob() const {
00043 return util::ReadNonPositiveFloat31(address_.base, address_.offset);
00044 }
00045
00046 float Backoff() const {
00047 return util::ReadFloat32(address_.base, address_.offset + 31);
00048 }
00049
00050 float Rest() const { return Prob(); }
00051
00052 void Write(float prob, float backoff) {
00053 util::WriteNonPositiveFloat31(address_.base, address_.offset, prob);
00054 util::WriteFloat32(address_.base, address_.offset + 31, backoff);
00055 }
00056
00057 private:
00058 util::BitAddress address_;
00059 };
00060
00061 class LongestPointer {
00062 public:
00063 explicit LongestPointer(const DontQuantize &, util::BitAddress address) : address_(address) {}
00064
00065 LongestPointer() : address_(NULL, 0) {}
00066
00067 bool Found() const {
00068 return address_.base != NULL;
00069 }
00070
00071 float Prob() const {
00072 return util::ReadNonPositiveFloat31(address_.base, address_.offset);
00073 }
00074
00075 void Write(float prob) {
00076 util::WriteNonPositiveFloat31(address_.base, address_.offset, prob);
00077 }
00078
00079 private:
00080 util::BitAddress address_;
00081 };
00082
00083 DontQuantize() {}
00084
00085 void SetupMemory(void * , unsigned char , const Config & ) {}
00086
00087 static const bool kTrain = false;
00088
00089 void Train(uint8_t , std::vector<float> &, std::vector<float> &) {}
00090 void TrainProb(uint8_t, std::vector<float> &) {}
00091
00092 void FinishedLoading(const Config &) {}
00093 };
00094
00095 class SeparatelyQuantize {
00096 private:
00097 class Bins {
00098 public:
00099
00100 Bins() {}
00101
00102 Bins(uint8_t bits, float *begin) : begin_(begin), end_(begin_ + (1ULL << bits)), bits_(bits), mask_((1ULL << bits) - 1) {}
00103
00104 float *Populate() { return begin_; }
00105
00106 uint64_t EncodeProb(float value) const {
00107 return Encode(value, 0);
00108 }
00109
00110 uint64_t EncodeBackoff(float value) const {
00111 if (value == 0.0) {
00112 return HasExtension(value) ? kExtensionQuant : kNoExtensionQuant;
00113 }
00114 return Encode(value, 2);
00115 }
00116
00117 float Decode(std::size_t off) const { return begin_[off]; }
00118
00119 uint8_t Bits() const { return bits_; }
00120
00121 uint64_t Mask() const { return mask_; }
00122
00123 private:
00124 uint64_t Encode(float value, size_t reserved) const {
00125 const float *above = std::lower_bound(static_cast<const float*>(begin_) + reserved, end_, value);
00126 if (above == begin_ + reserved) return reserved;
00127 if (above == end_) return end_ - begin_ - 1;
00128 return above - begin_ - (value - *(above - 1) < *above - value);
00129 }
00130
00131 float *begin_;
00132 const float *end_;
00133 uint8_t bits_;
00134 uint64_t mask_;
00135 };
00136
00137 public:
00138 static const ModelType kModelTypeAdd = kQuantAdd;
00139
00140 static void UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config);
00141
00142 static uint64_t Size(uint8_t order, const Config &config) {
00143 uint64_t longest_table = (static_cast<uint64_t>(1) << static_cast<uint64_t>(config.prob_bits)) * sizeof(float);
00144 uint64_t middle_table = (static_cast<uint64_t>(1) << static_cast<uint64_t>(config.backoff_bits)) * sizeof(float) + longest_table;
00145
00146 return (order - 2) * middle_table + longest_table + 8;
00147 }
00148
00149 static uint8_t MiddleBits(const Config &config) { return config.prob_bits + config.backoff_bits; }
00150 static uint8_t LongestBits(const Config &config) { return config.prob_bits; }
00151
00152 class MiddlePointer {
00153 public:
00154 MiddlePointer(const SeparatelyQuantize &quant, unsigned char order_minus_2, const util::BitAddress &address) : bins_(quant.GetTables(order_minus_2)), address_(address) {}
00155
00156 MiddlePointer() : address_(NULL, 0) {}
00157
00158 bool Found() const { return address_.base != NULL; }
00159
00160 float Prob() const {
00161 return ProbBins().Decode(util::ReadInt25(address_.base, address_.offset + BackoffBins().Bits(), ProbBins().Bits(), ProbBins().Mask()));
00162 }
00163
00164 float Backoff() const {
00165 return BackoffBins().Decode(util::ReadInt25(address_.base, address_.offset, BackoffBins().Bits(), BackoffBins().Mask()));
00166 }
00167
00168 float Rest() const { return Prob(); }
00169
00170 void Write(float prob, float backoff) const {
00171 util::WriteInt57(address_.base, address_.offset, ProbBins().Bits() + BackoffBins().Bits(),
00172 (ProbBins().EncodeProb(prob) << BackoffBins().Bits()) | BackoffBins().EncodeBackoff(backoff));
00173 }
00174
00175 private:
00176 const Bins &ProbBins() const { return bins_[0]; }
00177 const Bins &BackoffBins() const { return bins_[1]; }
00178 const Bins *bins_;
00179
00180 util::BitAddress address_;
00181 };
00182
00183 class LongestPointer {
00184 public:
00185 LongestPointer(const SeparatelyQuantize &quant, const util::BitAddress &address) : table_(&quant.LongestTable()), address_(address) {}
00186
00187 LongestPointer() : address_(NULL, 0) {}
00188
00189 bool Found() const { return address_.base != NULL; }
00190
00191 void Write(float prob) const {
00192 util::WriteInt25(address_.base, address_.offset, table_->Bits(), table_->EncodeProb(prob));
00193 }
00194
00195 float Prob() const {
00196 return table_->Decode(util::ReadInt25(address_.base, address_.offset, table_->Bits(), table_->Mask()));
00197 }
00198
00199 private:
00200 const Bins *table_;
00201 util::BitAddress address_;
00202 };
00203
00204 SeparatelyQuantize() {}
00205
00206 void SetupMemory(void *start, unsigned char order, const Config &config);
00207
00208 static const bool kTrain = true;
00209
00210 void Train(uint8_t order, std::vector<float> &prob, std::vector<float> &backoff);
00211
00212 void TrainProb(uint8_t order, std::vector<float> &prob);
00213
00214 void FinishedLoading(const Config &config);
00215
00216 const Bins *GetTables(unsigned char order_minus_2) const { return tables_[order_minus_2]; }
00217
00218 const Bins &LongestTable() const { return longest_; }
00219
00220 private:
00221 Bins tables_[KENLM_MAX_ORDER - 1][2];
00222
00223 Bins longest_;
00224
00225 uint8_t *actual_base_;
00226
00227 uint8_t prob_bits_, backoff_bits_;
00228 };
00229
00230 }
00231 }
00232
00233 #endif // LM_QUANTIZE_H