00001
00002
00003
00004
00005
00006
00007
00008
00009
00010 #ifdef LM_RAND
00011
00012 #include "LM/Base.h"
00013 #include "LM/LDHT.h"
00014 #include "moses/FFState.h"
00015 #include "moses/TypeDef.h"
00016 #include "moses/Hypothesis.h"
00017 #include "moses/StaticData.h"
00018 #include "util/exception.hh"
00019
00020 #include <LDHT/Client.h>
00021 #include <LDHT/ClientLocal.h>
00022 #include <LDHT/NewNgram.h>
00023 #include <LDHT/FactoryCollection.h>
00024
00025 #include <boost/thread/tss.hpp>
00026
00027 namespace Moses
00028 {
00029
00030 struct LDHTLMState : public FFState {
00031 LDHT::NewNgram gram_fingerprints;
00032 bool finalised;
00033 std::vector<int> request_tags;
00034
00035 LDHTLMState(): finalised(false) {
00036 }
00037
00038 void setFinalised() {
00039 this->finalised = true;
00040 }
00041
00042 void appendRequestTag(int tag) {
00043 this->request_tags.push_back(tag);
00044 }
00045
00046 void clearRequestTags() {
00047 this->request_tags.clear();
00048 }
00049
00050 std::vector<int>::iterator requestTagsBegin() {
00051 return this->request_tags.begin();
00052 }
00053
00054 std::vector<int>::iterator requestTagsEnd() {
00055 return this->request_tags.end();
00056 }
00057
00058 int Compare(const FFState& uncast_other) const {
00059 const LDHTLMState &other = static_cast<const LDHTLMState&>(uncast_other);
00060
00061
00062
00063 return gram_fingerprints.compareMoses(other.gram_fingerprints);
00064 }
00065
00066 void copyFrom(const LDHTLMState& other) {
00067 gram_fingerprints.copyFrom(other.gram_fingerprints);
00068 finalised = false;
00069 }
00070 };
00071
// Language model feature that delegates n-gram scoring to an LDHT
// (distributed hash table) service.  Score lookups are issued
// asynchronously while hypotheses are expanded (IssueRequestsFor) and
// collected later (Evaluate); each decoding thread gets its own client.
class LanguageModelLDHT : public LanguageModel
{
public:
  LanguageModelLDHT();
  LanguageModelLDHT(const std::string& path,
                    ScoreIndexManager& manager,
                    FactorType factorType);
  LanguageModelLDHT(ScoreIndexManager& manager,
                    LanguageModelLDHT& copyFrom);

  // Thread-local client access: the "Safe" variant lazily constructs the
  // calling thread's client, the "Unsafe" one assumes it already exists.
  LDHT::Client* getClientUnsafe() const;
  LDHT::Client* getClientSafe();
  // Build a new client for this thread from the two XML paths listed in
  // the config file at m_configPath.
  LDHT::Client* initTSSClient();
  virtual ~LanguageModelLDHT();
  virtual void InitializeForInput(ttasksptr const& ttask);
  // Prints per-sentence request statistics to stderr and resets counters.
  virtual void CleanUpAfterSentenceProcessing(const InputType &source);
  virtual const FFState* EmptyHypothesisState(const InputType& input) const;
  // Submits requests for all n-grams of the phrase; the out-params are
  // reported as 0 here — real scores only exist after sync().
  virtual void CalcScore(const Phrase& phrase,
                         float& fullScore,
                         float& ngramScore,
                         std::size_t& oovCount) const;
  // Like CalcScore but waits for responses and reads scores back.
  virtual void CalcScoreFromCache(const Phrase& phrase,
                                  float& fullScore,
                                  float& ngramScore,
                                  std::size_t& oovCount) const;
  FFState* Evaluate(const Hypothesis& hypo,
                    const FFState* input_state,
                    ScoreComponentCollection* score_output) const;
  // Chart decoding is not supported; see the definition below.
  FFState* EvaluateWhenApplied(const ChartHypothesis& hypo,
                               int featureID,
                               ScoreComponentCollection* accumulator) const;

  virtual void IssueRequestsFor(Hypothesis& hypo,
                                const FFState* input_state);
  float calcScoreFromState(LDHTLMState* hypo) const;
  // Blocks until every outstanding n-gram request has been answered.
  void sync();
  void SetFFStateIdx(int state_idx);

protected:
  boost::thread_specific_ptr<LDHT::Client> m_client; // one client per thread
  std::string m_configPath;  // file whose two lines name the LDHT XML configs
  FactorType m_factorType;   // which factor of each word is scored
  int m_state_idx;           // slot of this feature's FFState on a hypothesis
  int m_calc_score_count;    // CalcScore calls since last sync(), for batching
  uint64_t m_start_tick;     // rdtsc at sentence start, for latency reporting
};
00119
00120 LanguageModel* ConstructLDHTLM(const std::string& path,
00121 ScoreIndexManager& manager,
00122 FactorType factorType)
00123 {
00124 return new LanguageModelLDHT(path, manager, factorType);
00125 }
00126
00127 LanguageModelLDHT::LanguageModelLDHT() : LanguageModel(), m_client(NULL)
00128 {
00129 m_enableOOVFeature = false;
00130 }
00131
00132 LanguageModelLDHT::LanguageModelLDHT(ScoreIndexManager& manager,
00133 LanguageModelLDHT& copyFrom)
00134 {
00135 m_calc_score_count = 0;
00136
00137 m_factorType = copyFrom.m_factorType;
00138 m_configPath = copyFrom.m_configPath;
00139 Init(manager);
00140 }
00141
00142 LanguageModelLDHT::LanguageModelLDHT(const std::string& path,
00143 ScoreIndexManager& manager,
00144 FactorType factorType)
00145 : m_factorType(factorType)
00146 {
00147 m_configPath = path;
00148 Init(manager);
00149 }
00150
00151 LanguageModelLDHT::~LanguageModelLDHT()
00152 {
00153
00154
00155 }
00156
00157
00158
00159 LDHT::Client* LanguageModelLDHT::getClientSafe()
00160 {
00161 if (m_client.get() == NULL)
00162 m_client.reset(initTSSClient());
00163 return m_client.get();
00164 }
00165
00166
00167 LDHT::Client* LanguageModelLDHT::getClientUnsafe() const
00168 {
00169 return m_client.get();
00170 }
00171
00172 LDHT::Client* LanguageModelLDHT::initTSSClient()
00173 {
00174 std::ifstream config_file(m_configPath.c_str());
00175 std::string ldht_config_path;
00176 getline(config_file, ldht_config_path);
00177 std::string ldhtlm_config_path;
00178 getline(config_file, ldhtlm_config_path);
00179
00180 LDHT::FactoryCollection* factory_collection =
00181 LDHT::FactoryCollection::createDefaultFactoryCollection();
00182
00183 LDHT::Client* client;
00184
00185 client = new LDHT::Client();
00186 client->fromXmlFiles(*factory_collection,
00187 ldht_config_path,
00188 ldhtlm_config_path);
00189 return client;
00190 }
00191
00192 void LanguageModelLDHT::InitializeForInput(ttasksptr const& ttask)
00193 {
00194 getClientSafe()->clearCache();
00195 m_start_tick = LDHT::Util::rdtsc();
00196 }
00197
00198 void LanguageModelLDHT::CleanUpAfterSentenceProcessing(const InputType &source)
00199 {
00200 LDHT::Client* client = getClientSafe();
00201
00202 std::cerr << "LDHT sentence stats:" << std::endl;
00203 std::cerr << " ngrams submitted: " << client->getNumNgramsSubmitted() << std::endl
00204 << " ngrams requested: " << client->getNumNgramsRequested() << std::endl
00205 << " ngrams not found: " << client->getKeyNotFoundCount() << std::endl
00206 << " cache hits: " << client->getCacheHitCount() << std::endl
00207 << " inferences: " << client->getInferenceCount() << std::endl
00208 << " pcnt latency: " << (float)client->getLatencyTicks() / (float)(LDHT::Util::rdtsc() - m_start_tick) * 100.0 << std::endl;
00209 m_start_tick = 0;
00210 client->resetLatencyTicks();
00211 client->resetNumNgramsSubmitted();
00212 client->resetNumNgramsRequested();
00213 client->resetInferenceCount();
00214 client->resetCacheHitCount();
00215 client->resetKeyNotFoundCount();
00216 }
00217
00218 const FFState* LanguageModelLDHT::EmptyHypothesisState(
00219 const InputType& input) const
00220 {
00221 return NULL;
00222 }
00223
00224 void LanguageModelLDHT::CalcScore(const Phrase& phrase,
00225 float& fullScore,
00226 float& ngramScore,
00227 std::size_t& oovCount) const
00228 {
00229 const_cast<LanguageModelLDHT*>(this)->m_calc_score_count++;
00230 if (m_calc_score_count > 10000) {
00231 const_cast<LanguageModelLDHT*>(this)->m_calc_score_count = 0;
00232 const_cast<LanguageModelLDHT*>(this)->sync();
00233 }
00234
00235
00236 LDHT::Client* client = getClientUnsafe();
00237
00238 int order = LDHT::NewNgram::k_max_order;
00239 int prefix_start = 0;
00240 int prefix_end = std::min(phrase.GetSize(), static_cast<size_t>(order - 1));
00241 LDHT::NewNgram ngram;
00242 for (int word_idx = prefix_start; word_idx < prefix_end; ++word_idx) {
00243 ngram.appendGram(phrase.GetWord(word_idx)
00244 .GetFactor(m_factorType)->GetString().c_str());
00245 client->requestNgram(ngram);
00246 }
00247
00248 int internal_start = prefix_end;
00249 int internal_end = phrase.GetSize();
00250 for (int word_idx = internal_start; word_idx < internal_end; ++word_idx) {
00251 ngram.appendGram(phrase.GetWord(word_idx)
00252 .GetFactor(m_factorType)->GetString().c_str());
00253 client->requestNgram(ngram);
00254 }
00255
00256 fullScore = 0;
00257 ngramScore = 0;
00258 oovCount = 0;
00259 }
00260
00261 void LanguageModelLDHT::CalcScoreFromCache(const Phrase& phrase,
00262 float& fullScore,
00263 float& ngramScore,
00264 std::size_t& oovCount) const
00265 {
00266
00267
00268 const_cast<LanguageModelLDHT*>(this)->sync();
00269
00270
00271 LDHT::Client* client = getClientUnsafe();
00272
00273 int order = LDHT::NewNgram::k_max_order;
00274 int prefix_start = 0;
00275 int prefix_end = std::min(phrase.GetSize(), static_cast<size_t>(order - 1));
00276 LDHT::NewNgram ngram;
00277 std::deque<int> full_score_tags;
00278 for (int word_idx = prefix_start; word_idx < prefix_end; ++word_idx) {
00279 ngram.appendGram(phrase.GetWord(word_idx)
00280 .GetFactor(m_factorType)->GetString().c_str());
00281 full_score_tags.push_back(client->requestNgram(ngram));
00282 }
00283
00284 int internal_start = prefix_end;
00285 int internal_end = phrase.GetSize();
00286 std::deque<int> internal_score_tags;
00287 for (int word_idx = internal_start; word_idx < internal_end; ++word_idx) {
00288 ngram.appendGram(phrase.GetWord(word_idx)
00289 .GetFactor(m_factorType)->GetString().c_str());
00290 internal_score_tags.push_back(client->requestNgram(ngram));
00291 }
00292
00293
00294
00295
00296
00297 fullScore = 0.0;
00298 while (!full_score_tags.empty()) {
00299 fullScore += client->getNgramScore(full_score_tags.front());
00300 full_score_tags.pop_front();
00301 }
00302 ngramScore = 0.0;
00303 while (!internal_score_tags.empty()) {
00304 float score = client->getNgramScore(internal_score_tags.front());
00305 internal_score_tags.pop_front();
00306 fullScore += score;
00307 ngramScore += score;
00308 }
00309 fullScore = TransformLMScore(fullScore);
00310 ngramScore = TransformLMScore(ngramScore);
00311 oovCount = 0;
00312 }
00313
// Attach a new LM state to the hypothesis and asynchronously request the
// scores of the n-grams this extension creates.  Scores are collected
// later in Evaluate() via the tags stored on the state.
void LanguageModelLDHT::IssueRequestsFor(Hypothesis& hypo,
                                         const FFState* input_state)
{
  LDHT::Client* client = getClientUnsafe();

  // Build the new state: either seed the context with <s> at sentence
  // start, or continue from the previous hypothesis' fingerprints.
  LDHTLMState* new_state = new LDHTLMState();
  if (input_state == NULL) {
    if (hypo.GetCurrTargetWordsRange().GetStartPos() != 0) {
      UTIL_THROW2("got a null state but not at start of sentence");
    }
    new_state->gram_fingerprints.appendGram(BOS_);
  } else {
    if (hypo.GetCurrTargetWordsRange().GetStartPos() == 0) {
      UTIL_THROW2("got a non null state but at start of sentence");
    }
    new_state->copyFrom(static_cast<const LDHTLMState&>(*input_state));
  }

  // The first (order-1) words form n-grams that span the boundary with the
  // previous context, so each one is submitted for scoring here.
  int order = LDHT::NewNgram::k_max_order;
  int phrase_start = hypo.GetCurrTargetWordsRange().GetStartPos();
  int phrase_end = hypo.GetCurrTargetWordsRange().GetEndPos() + 1;
  int overlap_start = phrase_start;
  int overlap_end = std::min(phrase_end, phrase_start + order - 1);
  int word_idx = overlap_start;
  LDHT::NewNgram& ngram = new_state->gram_fingerprints;
  for (; word_idx < overlap_end; ++word_idx) {
    ngram.appendGram(
      hypo.GetFactor(word_idx, m_factorType)->GetString().c_str());
    new_state->appendRequestTag(client->requestNgram(ngram));
  }

  // The remaining words only extend the fingerprint context; no request is
  // issued for them — presumably their scores were already covered by the
  // phrase-internal pass in CalcScoreFromCache().  NOTE(review): confirm.
  for (; word_idx < phrase_end; ++word_idx) {
    ngram.appendGram(
      hypo.GetFactor(word_idx, m_factorType)->GetString().c_str());
  }

  // At sentence end, also request the final n-gram including </s>.
  if (hypo.IsSourceCompleted()) {
    ngram.appendGram(EOS_);

    new_state->appendRequestTag(client->requestNgram(ngram));
  }
  // Ownership of new_state passes to the hypothesis.
  hypo.SetFFState(m_state_idx, new_state);
}
00364
00365 void LanguageModelLDHT::sync()
00366 {
00367 m_calc_score_count = 0;
00368 getClientUnsafe()->awaitResponses();
00369 }
00370
00371 void LanguageModelLDHT::SetFFStateIdx(int state_idx)
00372 {
00373 m_state_idx = state_idx;
00374 }
00375
00376 FFState* LanguageModelLDHT::Evaluate(
00377 const Hypothesis& hypo,
00378 const FFState* input_state_ignored,
00379 ScoreComponentCollection* score_output) const
00380 {
00381
00382
00383
00384
00385
00386 LDHTLMState* state = const_cast<LDHTLMState*>(static_cast<const LDHTLMState*>(hypo.GetFFState(m_state_idx)));
00387
00388 float score = calcScoreFromState(state);
00389 score = FloorScore(TransformLMScore(score));
00390 score_output->PlusEquals(this, score);
00391
00392 return state;
00393 }
00394
00395 FFState* LanguageModelLDHT::EvaluateWhenApplied(
00396 const ChartHypothesis& hypo,
00397 int featureID,
00398 ScoreComponentCollection* accumulator) const
00399 {
00400 return NULL;
00401 }
00402
00403 float LanguageModelLDHT::calcScoreFromState(LDHTLMState* state) const
00404 {
00405 float score = 0.0;
00406 std::vector<int>::iterator tag_iter;
00407 LDHT::Client* client = getClientUnsafe();
00408 for (tag_iter = state->requestTagsBegin();
00409 tag_iter != state->requestTagsEnd();
00410 ++tag_iter) {
00411 score += client->getNgramScore(*tag_iter);
00412 }
00413 state->clearRequestTags();
00414 state->setFinalised();
00415 return score;
00416 }
00417
00418 }
00419
00420 #endif