00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020 #include <algorithm>
00021 #include <iostream>
00022 #include <vector>
00023 #include "StaticData.h"
00024 #include "ChartTranslationOptionList.h"
00025 #include "ChartTranslationOptions.h"
00026 #include "ChartCellCollection.h"
00027 #include "Range.h"
00028 #include "InputType.h"
00029 #include "InputPath.h"
00030
00031 using namespace std;
00032
00033 namespace Moses
00034 {
00035
00036 ChartTranslationOptionList::
00037 ChartTranslationOptionList(size_t ruleLimit, const InputType &input)
00038 : m_size(0)
00039 , m_ruleLimit(ruleLimit)
00040 {
00041 m_scoreThreshold = std::numeric_limits<float>::infinity();
00042 }
00043
00044 ChartTranslationOptionList::~ChartTranslationOptionList()
00045 {
00046 RemoveAllInColl(m_collection);
00047 }
00048
00049 void ChartTranslationOptionList::Clear()
00050 {
00051 m_size = 0;
00052 m_scoreThreshold = std::numeric_limits<float>::infinity();
00053 }
00054
00055 class ChartTranslationOptionOrderer
00056 {
00057 public:
00058 bool operator()(const ChartTranslationOptions* itemA, const ChartTranslationOptions* itemB) const {
00059 return itemA->GetEstimateOfBestScore() > itemB->GetEstimateOfBestScore();
00060 }
00061 };
00062
00063 void ChartTranslationOptionList::Add(const TargetPhraseCollection &tpc,
00064 const StackVec &stackVec,
00065 const Range &range)
00066 {
00067 if (tpc.IsEmpty()) {
00068 return;
00069 }
00070
00071 for (size_t i = 0; i < stackVec.size(); ++i) {
00072 const ChartCellLabel &chartCellLabel = *stackVec[i];
00073 size_t numHypos = chartCellLabel.GetStack().cube->size();
00074 if (numHypos == 0) {
00075 return;
00076 }
00077 }
00078
00079 const TargetPhrase &targetPhrase = **(tpc.begin());
00080 float score = targetPhrase.GetFutureScore();
00081 for (StackVec::const_iterator p = stackVec.begin(); p != stackVec.end(); ++p) {
00082 score += (*p)->GetBestScore(this);
00083 }
00084
00085
00086
00087 if (m_ruleLimit && m_size > m_ruleLimit && score < m_scoreThreshold) {
00088 return;
00089 }
00090
00091
00092 if (m_size == m_collection.size()) {
00093
00094 m_collection.push_back(new ChartTranslationOptions(tpc, stackVec,
00095 range, score));
00096 } else {
00097
00098 *(m_collection[m_size]) = ChartTranslationOptions(tpc, stackVec,
00099 range, score);
00100 }
00101 ++m_size;
00102
00103
00104 if (!m_ruleLimit || m_size <= m_ruleLimit) {
00105 m_scoreThreshold = (score < m_scoreThreshold) ? score : m_scoreThreshold;
00106 }
00107
00108
00109 if (m_ruleLimit && m_size == m_ruleLimit * 2) {
00110 NTH_ELEMENT4(m_collection.begin(),
00111 m_collection.begin() + m_ruleLimit - 1,
00112 m_collection.begin() + m_size,
00113 ChartTranslationOptionOrderer());
00114 m_scoreThreshold = m_collection[m_ruleLimit-1]->GetEstimateOfBestScore();
00115 m_size = m_ruleLimit;
00116 }
00117 }
00118
00119 void
00120 ChartTranslationOptionList::
00121 AddPhraseOOV(TargetPhrase &phrase,
00122 std::list<TargetPhraseCollection::shared_ptr > &waste_memory,
00123 const Range &range)
00124 {
00125 TargetPhraseCollection::shared_ptr tpc(new TargetPhraseCollection);
00126 tpc->Add(&phrase);
00127 waste_memory.push_back(tpc);
00128 StackVec empty;
00129 Add(*tpc, empty, range);
00130 }
00131
00132 void ChartTranslationOptionList::ApplyThreshold(float const threshold)
00133 {
00134 if (m_ruleLimit && m_size > m_ruleLimit) {
00135
00136
00137 assert(m_size < m_ruleLimit * 2);
00138
00139
00140 NTH_ELEMENT4(m_collection.begin(),
00141 m_collection.begin()+m_ruleLimit,
00142 m_collection.begin()+m_size,
00143 ChartTranslationOptionOrderer());
00144 m_size = m_ruleLimit;
00145 }
00146
00147
00148
00149 float scoreThreshold = -std::numeric_limits<float>::infinity();
00150
00151 CollType::const_iterator iter;
00152 for (iter = m_collection.begin(); iter != m_collection.begin()+m_size; ++iter) {
00153 const ChartTranslationOptions *transOpt = *iter;
00154 float score = transOpt->GetEstimateOfBestScore();
00155 scoreThreshold = (score > scoreThreshold) ? score : scoreThreshold;
00156 }
00157
00158 scoreThreshold += threshold;
00159
00160 CollType::iterator bound = std::partition(m_collection.begin(),
00161 m_collection.begin()+m_size,
00162 ScoreThresholdPred(scoreThreshold));
00163
00164 m_size = std::distance(m_collection.begin(), bound);
00165 }
00166
00167 float ChartTranslationOptionList::GetBestScore(const ChartCellLabel *chartCell) const
00168 {
00169 const HypoList *stack = chartCell->GetStack().cube;
00170 assert(stack);
00171 assert(!stack->empty());
00172 const ChartHypothesis &bestHypo = **(stack->begin());
00173 return bestHypo.GetFutureScore();
00174 }
00175
00176 void ChartTranslationOptionList::EvaluateWithSourceContext(const InputType &input, const InputPath &inputPath)
00177 {
00178
00179 CollType::iterator iter;
00180 for (iter = m_collection.begin(); iter != m_collection.begin() + m_size; ++iter) {
00181 ChartTranslationOptions &transOpts = **iter;
00182 transOpts.EvaluateWithSourceContext(input, inputPath);
00183 }
00184
00185
00186 size_t numDiscard = 0;
00187 for (size_t i = 0; i < m_size; ++i) {
00188 ChartTranslationOptions *transOpts = m_collection[i];
00189 if (transOpts->GetSize() == 0) {
00190
00191 ++numDiscard;
00192 } else if (numDiscard) {
00193 SwapTranslationOptions(i - numDiscard, i);
00194
00195 }
00196 }
00197
00198 size_t newSize = m_size - numDiscard;
00199 m_size = newSize;
00200 }
00201
00202 void ChartTranslationOptionList::SwapTranslationOptions(size_t a, size_t b)
00203 {
00204 ChartTranslationOptions *transOptsA = m_collection[a];
00205 ChartTranslationOptions *transOptsB = m_collection[b];
00206 m_collection[a] = transOptsB;
00207 m_collection[b] = transOptsA;
00208 }
00209
00210 std::ostream& operator<<(std::ostream &out, const ChartTranslationOptionList &obj)
00211 {
00212 for (size_t i = 0; i < obj.m_collection.size(); ++i) {
00213 const ChartTranslationOptions &transOpts = *obj.m_collection[i];
00214 out << transOpts << endl;
00215 }
00216 return out;
00217 }
00218
00219 }