00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021 #include <algorithm>
00022 #include <vector>
00023 #include "ChartHypothesis.h"
00024 #include "RuleCubeItem.h"
00025 #include "ChartCell.h"
00026 #include "ChartManager.h"
00027 #include "TargetPhrase.h"
00028 #include "Phrase.h"
00029 #include "StaticData.h"
00030 #include "ChartTranslationOptions.h"
00031 #include "moses/FF/FFState.h"
00032 #include "moses/FF/StatefulFeatureFunction.h"
00033 #include "moses/FF/StatelessFeatureFunction.h"
00034
00035 using namespace std;
00036
00037 namespace Moses
00038 {
00039
00045 ChartHypothesis::ChartHypothesis(const ChartTranslationOptions &transOpt,
00046 const RuleCubeItem &item,
00047 ChartManager &manager)
00048 :m_transOpt(item.GetTranslationDimension().GetTranslationOption())
00049 ,m_currSourceWordsRange(transOpt.GetSourceWordsRange())
00050 ,m_ffStates(StatefulFeatureFunction::GetStatefulFeatureFunctions().size())
00051 ,m_arcList(NULL)
00052 ,m_winningHypo(NULL)
00053 ,m_manager(manager)
00054 ,m_id(manager.GetNextHypoId())
00055 {
00056
00057 const std::vector<HypothesisDimension> &childEntries = item.GetHypothesisDimensions();
00058 m_prevHypos.reserve(childEntries.size());
00059 std::vector<HypothesisDimension>::const_iterator iter;
00060 for (iter = childEntries.begin(); iter != childEntries.end(); ++iter) {
00061 m_prevHypos.push_back(iter->GetHypothesis());
00062 }
00063 }
00064
00065
00066
00067 ChartHypothesis::ChartHypothesis(const ChartHypothesis &pred,
00068 const ChartKBestExtractor & )
00069 :m_currSourceWordsRange(pred.m_currSourceWordsRange)
00070 ,m_totalScore(pred.m_totalScore)
00071 ,m_arcList(NULL)
00072 ,m_winningHypo(NULL)
00073 ,m_manager(pred.m_manager)
00074 ,m_id(pred.m_manager.GetNextHypoId())
00075 {
00076
00077 m_prevHypos.push_back(&pred);
00078 }
00079
00080 ChartHypothesis::~ChartHypothesis()
00081 {
00082
00083 for (unsigned i = 0; i < m_ffStates.size(); ++i) {
00084 delete m_ffStates[i];
00085 }
00086
00087
00088 if (m_arcList) {
00089 ChartArcList::iterator iter;
00090 for (iter = m_arcList->begin() ; iter != m_arcList->end() ; ++iter) {
00091 ChartHypothesis *hypo = *iter;
00092 delete hypo;
00093 }
00094 m_arcList->clear();
00095
00096 delete m_arcList;
00097 }
00098 }
00099
00103 void ChartHypothesis::GetOutputPhrase(Phrase &outPhrase) const
00104 {
00105 FactorType placeholderFactor = StaticData::Instance().options()->input.placeholder_factor;
00106
00107 for (size_t pos = 0; pos < GetCurrTargetPhrase().GetSize(); ++pos) {
00108 const Word &word = GetCurrTargetPhrase().GetWord(pos);
00109 if (word.IsNonTerminal()) {
00110
00111 size_t nonTermInd = GetCurrTargetPhrase().GetAlignNonTerm().GetNonTermIndexMap()[pos];
00112 const ChartHypothesis *prevHypo = m_prevHypos[nonTermInd];
00113 prevHypo->GetOutputPhrase(outPhrase);
00114 } else {
00115 outPhrase.AddWord(word);
00116
00117 if (placeholderFactor != NOT_FOUND) {
00118 std::set<size_t> sourcePosSet = GetCurrTargetPhrase().GetAlignTerm().GetAlignmentsForTarget(pos);
00119 if (sourcePosSet.size() == 1) {
00120 const std::vector<const Word*> *ruleSourceFromInputPath = GetTranslationOption().GetSourceRuleFromInputPath();
00121 UTIL_THROW_IF2(ruleSourceFromInputPath == NULL,
00122 "No source rule");
00123
00124 size_t sourcePos = *sourcePosSet.begin();
00125 const Word *sourceWord = ruleSourceFromInputPath->at(sourcePos);
00126 UTIL_THROW_IF2(sourceWord == NULL,
00127 "No source word");
00128 const Factor *factor = sourceWord->GetFactor(placeholderFactor);
00129 if (factor) {
00130 outPhrase.Back()[0] = factor;
00131 }
00132 }
00133 }
00134
00135 }
00136 }
00137 }
00138
00140 Phrase ChartHypothesis::GetOutputPhrase() const
00141 {
00142 Phrase outPhrase(ARRAY_SIZE_INCR);
00143 GetOutputPhrase(outPhrase);
00144 return outPhrase;
00145 }
00146
00148 void ChartHypothesis::GetOutputPhrase(size_t leftRightMost, size_t numWords, Phrase &outPhrase) const
00149 {
00150 const TargetPhrase &tp = GetCurrTargetPhrase();
00151
00152 size_t targetSize = tp.GetSize();
00153 for (size_t i = 0; i < targetSize; ++i) {
00154 size_t pos;
00155 if (leftRightMost == 1) {
00156 pos = i;
00157 } else if (leftRightMost == 2) {
00158 pos = targetSize - i - 1;
00159 } else {
00160 abort();
00161 }
00162
00163 const Word &word = tp.GetWord(pos);
00164
00165 if (word.IsNonTerminal()) {
00166
00167 size_t nonTermInd = tp.GetAlignNonTerm().GetNonTermIndexMap()[pos];
00168 const ChartHypothesis *prevHypo = m_prevHypos[nonTermInd];
00169 prevHypo->GetOutputPhrase(outPhrase);
00170 } else {
00171 outPhrase.AddWord(word);
00172 }
00173
00174 if (outPhrase.GetSize() >= numWords) {
00175 return;
00176 }
00177 }
00178 }
00179
00181 void ChartHypothesis::EvaluateWhenApplied()
00182 {
00183 const StaticData &staticData = StaticData::Instance();
00184
00185
00186
00187 const std::vector<const StatelessFeatureFunction*>& sfs =
00188 StatelessFeatureFunction::GetStatelessFeatureFunctions();
00189 for (unsigned i = 0; i < sfs.size(); ++i) {
00190 if (! staticData.IsFeatureFunctionIgnored( *sfs[i] )) {
00191 sfs[i]->EvaluateWhenApplied(*this,&m_currScoreBreakdown);
00192 }
00193 }
00194
00195 const std::vector<const StatefulFeatureFunction*>& ffs =
00196 StatefulFeatureFunction::GetStatefulFeatureFunctions();
00197 for (unsigned i = 0; i < ffs.size(); ++i) {
00198 if (! staticData.IsFeatureFunctionIgnored( *ffs[i] )) {
00199 m_ffStates[i] = ffs[i]->EvaluateWhenApplied(*this,i,&m_currScoreBreakdown);
00200 }
00201 }
00202
00203
00204 m_totalScore = GetTranslationOption().GetScores().GetWeightedScore();
00205 m_totalScore += m_currScoreBreakdown.GetWeightedScore();
00206
00207
00208 for (std::vector<const ChartHypothesis*>::const_iterator iter = m_prevHypos.begin(); iter != m_prevHypos.end(); ++iter) {
00209 const ChartHypothesis &prevHypo = **iter;
00210 m_totalScore += prevHypo.GetFutureScore();
00211 }
00212 }
00213
00214 void ChartHypothesis::AddArc(ChartHypothesis *loserHypo)
00215 {
00216 if (!m_arcList) {
00217 if (loserHypo->m_arcList) {
00218
00219 this->m_arcList = loserHypo->m_arcList;
00220 loserHypo->m_arcList = 0;
00221 } else {
00222 this->m_arcList = new ChartArcList();
00223 }
00224 } else {
00225 if (loserHypo->m_arcList) {
00226
00227 size_t my_size = m_arcList->size();
00228 size_t add_size = loserHypo->m_arcList->size();
00229 this->m_arcList->resize(my_size + add_size, 0);
00230 std::memcpy(&(*m_arcList)[0] + my_size, &(*loserHypo->m_arcList)[0], add_size * sizeof(ChartHypothesis *));
00231 delete loserHypo->m_arcList;
00232 loserHypo->m_arcList = 0;
00233 } else {
00234
00235
00236 }
00237 }
00238 m_arcList->push_back(loserHypo);
00239 }
00240
00241
00242 struct CompareChartHypothesisTotalScore {
00243 bool operator()(const ChartHypothesis* hypo1, const ChartHypothesis* hypo2) const {
00244 return hypo1->GetFutureScore() > hypo2->GetFutureScore();
00245 }
00246 };
00247
00248 void ChartHypothesis::CleanupArcList()
00249 {
00250
00251 m_winningHypo = this;
00252
00253 if (!m_arcList) return;
00254
00255
00256
00257
00258
00259 AllOptions const& opts = *StaticData::Instance().options();
00260 size_t nBestSize = opts.nbest.nbest_size;
00261 bool distinctNBest = (opts.nbest.only_distinct
00262 || opts.mbr.enabled
00263 || opts.output.NeedSearchGraph()
00264 || !opts.output.SearchGraphHG.empty());
00265
00266 if (!distinctNBest && m_arcList->size() > nBestSize) {
00267
00268 NTH_ELEMENT4(m_arcList->begin()
00269 , m_arcList->begin() + nBestSize - 1
00270 , m_arcList->end()
00271 , CompareChartHypothesisTotalScore());
00272
00273
00274 ChartArcList::iterator iter;
00275 for (iter = m_arcList->begin() + nBestSize ; iter != m_arcList->end() ; ++iter) {
00276 ChartHypothesis *arc = *iter;
00277 delete arc;
00278 }
00279 m_arcList->erase(m_arcList->begin() + nBestSize
00280 , m_arcList->end());
00281 }
00282
00283
00284 ChartArcList::iterator iter = m_arcList->begin();
00285 for (; iter != m_arcList->end() ; ++iter) {
00286 ChartHypothesis *arc = *iter;
00287 arc->SetWinningHypo(this);
00288 }
00289
00290
00291 }
00292
00293 void ChartHypothesis::SetWinningHypo(const ChartHypothesis *hypo)
00294 {
00295 m_winningHypo = hypo;
00296 }
00297
00298 size_t ChartHypothesis::hash() const
00299 {
00300 size_t seed = 0;
00301
00302
00303 for (size_t i = 0; i < m_ffStates.size(); ++i) {
00304 const FFState *state = m_ffStates[i];
00305 size_t hash = state->hash();
00306 boost::hash_combine(seed, hash);
00307 }
00308 return seed;
00309
00310 }
00311
00312 bool ChartHypothesis::operator==(const ChartHypothesis& other) const
00313 {
00314
00315 for (size_t i = 0; i < m_ffStates.size(); ++i) {
00316 const FFState &thisState = *m_ffStates[i];
00317 const FFState &otherState = *other.m_ffStates[i];
00318 if (thisState != otherState) {
00319 return false;
00320 }
00321 }
00322 return true;
00323 }
00324
00325 TO_STRING_BODY(ChartHypothesis)
00326
00327
00328 std::ostream& operator<<(std::ostream& out, const ChartHypothesis& hypo)
00329 {
00330
00331 out << hypo.GetId();
00332
00333
00334 if (hypo.GetWinningHypothesis() != NULL &&
00335 hypo.GetWinningHypothesis() != &hypo) {
00336 out << "->" << hypo.GetWinningHypothesis()->GetId();
00337 }
00338
00339 if (hypo.GetManager().options()->output.include_lhs_in_search_graph) {
00340 out << " " << hypo.GetTargetLHS() << "=>";
00341 }
00342 out << " " << hypo.GetCurrTargetPhrase()
00343
00344 << " " << hypo.GetCurrSourceRange();
00345
00346 HypoList::const_iterator iter;
00347 for (iter = hypo.GetPrevHypos().begin(); iter != hypo.GetPrevHypos().end(); ++iter) {
00348 const ChartHypothesis &prevHypo = **iter;
00349 out << " " << prevHypo.GetId();
00350 }
00351
00352 out << " [total=" << hypo.GetFutureScore() << "]";
00353 out << " " << hypo.GetScoreBreakdown();
00354
00355
00356
00357 return out;
00358 }
00359
00360 }