00001 #include "ConstrainedDecoding.h"
00002 #include "moses/Hypothesis.h"
00003 #include "moses/Manager.h"
00004 #include "moses/ChartHypothesis.h"
00005 #include "moses/ChartManager.h"
00006 #include "moses/StaticData.h"
00007 #include "moses/InputFileStream.h"
00008 #include "moses/Util.h"
00009 #include "util/exception.hh"
00010
00011 using namespace std;
00012
00013 namespace Moses
00014 {
00015 ConstrainedDecodingState::ConstrainedDecodingState(const Hypothesis &hypo)
00016 {
00017 hypo.GetOutputPhrase(m_outputPhrase);
00018 }
00019
00020 ConstrainedDecodingState::ConstrainedDecodingState(const ChartHypothesis &hypo)
00021 {
00022 hypo.GetOutputPhrase(m_outputPhrase);
00023 }
00024
00025 size_t ConstrainedDecodingState::hash() const
00026 {
00027 size_t ret = hash_value(m_outputPhrase);
00028 return ret;
00029 }
00030
00031 bool ConstrainedDecodingState::operator==(const FFState& other) const
00032 {
00033 const ConstrainedDecodingState &otherFF = static_cast<const ConstrainedDecodingState&>(other);
00034 bool ret = m_outputPhrase == otherFF.m_outputPhrase;
00035 return ret;
00036 }
00037
00039 ConstrainedDecoding::ConstrainedDecoding(const std::string &line)
00040 :StatefulFeatureFunction(1, line)
00041 ,m_maxUnknowns(0)
00042 ,m_negate(false)
00043 ,m_soft(false)
00044 {
00045 m_tuneable = false;
00046 ReadParameters();
00047 }
00048
00049 void ConstrainedDecoding::Load(AllOptions::ptr const& opts)
00050 {
00051 m_options = opts;
00052 const StaticData &staticData = StaticData::Instance();
00053 bool addBeginEndWord
00054 = ((opts->search.algo == CYKPlus) || (opts->search.algo == ChartIncremental));
00055
00056 for(size_t i = 0; i < m_paths.size(); ++i) {
00057 InputFileStream constraintFile(m_paths[i]);
00058 std::string line;
00059 long sentenceID = opts->output.start_translation_id - 1 ;
00060 while (getline(constraintFile, line)) {
00061 vector<string> vecStr = Tokenize(line, "\t");
00062
00063 Phrase phrase(0);
00064 if (vecStr.size() == 1) {
00065 sentenceID++;
00066 phrase.CreateFromString(Output, opts->output.factor_order, vecStr[0], NULL);
00067 } else if (vecStr.size() == 2) {
00068 sentenceID = Scan<long>(vecStr[0]);
00069 phrase.CreateFromString(Output, opts->output.factor_order, vecStr[1], NULL);
00070 } else {
00071 UTIL_THROW(util::Exception, "Reference file not loaded");
00072 }
00073
00074 if (addBeginEndWord) {
00075 phrase.InitStartEndWord();
00076 }
00077 m_constraints[sentenceID].push_back(phrase);
00078 }
00079 }
00080 }
00081
00082 std::vector<float> ConstrainedDecoding::DefaultWeights() const
00083 {
00084 UTIL_THROW_IF2(m_numScoreComponents != 1,
00085 "ConstrainedDecoding must only have 1 score");
00086 vector<float> ret(1, 1);
00087 return ret;
00088 }
00089
00090 template <class H, class M>
00091 const std::vector<Phrase> *GetConstraint(const std::map<long,std::vector<Phrase> > &constraints, const H &hypo)
00092 {
00093 const M &mgr = hypo.GetManager();
00094 const InputType &input = mgr.GetSource();
00095 long id = input.GetTranslationId();
00096
00097 map<long,std::vector<Phrase> >::const_iterator iter;
00098 iter = constraints.find(id);
00099
00100 if (iter == constraints.end()) {
00101 UTIL_THROW(util::Exception, "Couldn't find reference " << id);
00102
00103 return NULL;
00104 } else {
00105 return &iter->second;
00106 }
00107 }
00108
00109 FFState* ConstrainedDecoding::EvaluateWhenApplied(
00110 const Hypothesis& hypo,
00111 const FFState* prev_state,
00112 ScoreComponentCollection* accumulator) const
00113 {
00114 const std::vector<Phrase> *ref = GetConstraint<Hypothesis, Manager>(m_constraints, hypo);
00115 assert(ref);
00116
00117 ConstrainedDecodingState *ret = new ConstrainedDecodingState(hypo);
00118 const Phrase &outputPhrase = ret->GetPhrase();
00119
00120 size_t searchPos = NOT_FOUND;
00121 size_t i = 0;
00122 size_t size = 0;
00123 while(searchPos == NOT_FOUND && i < ref->size()) {
00124 searchPos = (*ref)[i].Find(outputPhrase, m_maxUnknowns);
00125 size = (*ref)[i].GetSize();
00126 i++;
00127 }
00128
00129 float score;
00130 if (hypo.IsSourceCompleted()) {
00131
00132 bool match = (searchPos == 0) && (size == outputPhrase.GetSize());
00133 if (!m_negate) {
00134 score = match ? 0 : - ( m_soft ? 1 : std::numeric_limits<float>::infinity());
00135 } else {
00136 score = !match ? 0 : - ( m_soft ? 1 : std::numeric_limits<float>::infinity());
00137 }
00138 } else if (m_negate) {
00139
00140 score = 0;
00141 } else {
00142 score = (searchPos != NOT_FOUND) ? 0 : - ( m_soft ? 1 : std::numeric_limits<float>::infinity());
00143 }
00144
00145 accumulator->PlusEquals(this, score);
00146
00147 return ret;
00148 }
00149
00150 FFState* ConstrainedDecoding::EvaluateWhenApplied(
00151 const ChartHypothesis &hypo,
00152 int ,
00153 ScoreComponentCollection* accumulator) const
00154 {
00155 const std::vector<Phrase> *ref = GetConstraint<ChartHypothesis, ChartManager>(m_constraints, hypo);
00156 assert(ref);
00157
00158 const ChartManager &mgr = hypo.GetManager();
00159 const Sentence &source = static_cast<const Sentence&>(mgr.GetSource());
00160
00161 ConstrainedDecodingState *ret = new ConstrainedDecodingState(hypo);
00162 const Phrase &outputPhrase = ret->GetPhrase();
00163
00164 size_t searchPos = NOT_FOUND;
00165 size_t i = 0;
00166 size_t size = 0;
00167 while(searchPos == NOT_FOUND && i < ref->size()) {
00168 searchPos = (*ref)[i].Find(outputPhrase, m_maxUnknowns);
00169 size = (*ref)[i].GetSize();
00170 i++;
00171 }
00172
00173 float score;
00174 if (hypo.GetCurrSourceRange().GetStartPos() == 0 &&
00175 hypo.GetCurrSourceRange().GetEndPos() == source.GetSize() - 1) {
00176
00177 bool match = (searchPos == 0) && (size == outputPhrase.GetSize());
00178
00179 if (!m_negate) {
00180 score = match ? 0 : - ( m_soft ? 1 : std::numeric_limits<float>::infinity());
00181 } else {
00182 score = !match ? 0 : - ( m_soft ? 1 : std::numeric_limits<float>::infinity());
00183 }
00184 } else if (m_negate) {
00185
00186 score = 0;
00187 } else {
00188 score = (searchPos != NOT_FOUND) ? 0 : - ( m_soft ? 1 : std::numeric_limits<float>::infinity());
00189 }
00190
00191 accumulator->PlusEquals(this, score);
00192
00193 return ret;
00194 }
00195
00196 void ConstrainedDecoding::SetParameter(const std::string& key, const std::string& value)
00197 {
00198 if (key == "path") {
00199 m_paths = Tokenize(value, ",");
00200 } else if (key == "max-unknowns") {
00201 m_maxUnknowns = Scan<int>(value);
00202 } else if (key == "negate") {
00203 m_negate = Scan<bool>(value);
00204 } else if (key == "soft") {
00205 m_soft = Scan<bool>(value);
00206 } else {
00207 StatefulFeatureFunction::SetParameter(key, value);
00208 }
00209 }
00210
00211 }
00212