00001 #include "ConstrainedDecoding.h"
00002 #include "moses/Hypothesis.h"
00003 #include "moses/Manager.h"
00004 #include "moses/ChartHypothesis.h"
00005 #include "moses/ChartManager.h"
00006 #include "moses/StaticData.h"
00007 #include "moses/InputFileStream.h"
00008 #include "moses/Util.h"
00009 #include "util/exception.hh"
00010 
00011 using namespace std;
00012 
00013 namespace Moses
00014 {
00015 ConstrainedDecodingState::ConstrainedDecodingState(const Hypothesis &hypo)
00016 {
00017   hypo.GetOutputPhrase(m_outputPhrase);
00018 }
00019 
00020 ConstrainedDecodingState::ConstrainedDecodingState(const ChartHypothesis &hypo)
00021 {
00022   hypo.GetOutputPhrase(m_outputPhrase);
00023 }
00024 
00025 size_t ConstrainedDecodingState::hash() const
00026 {
00027   size_t ret = hash_value(m_outputPhrase);
00028   return ret;
00029 }
00030 
00031 bool ConstrainedDecodingState::operator==(const FFState& other) const
00032 {
00033   const ConstrainedDecodingState &otherFF = static_cast<const ConstrainedDecodingState&>(other);
00034   bool ret = m_outputPhrase == otherFF.m_outputPhrase;
00035   return ret;
00036 }
00037 
00039 ConstrainedDecoding::ConstrainedDecoding(const std::string &line)
00040   :StatefulFeatureFunction(1, line)
00041   ,m_maxUnknowns(0)
00042   ,m_negate(false)
00043   ,m_soft(false)
00044 {
00045   m_tuneable = false;
00046   ReadParameters();
00047 }
00048 
00049 void ConstrainedDecoding::Load(AllOptions::ptr const& opts)
00050 {
00051   m_options = opts;
00052   const StaticData &staticData = StaticData::Instance();
00053   bool addBeginEndWord
00054   = ((opts->search.algo == CYKPlus) || (opts->search.algo == ChartIncremental));
00055 
00056   for(size_t i = 0; i < m_paths.size(); ++i) {
00057     InputFileStream constraintFile(m_paths[i]);
00058     std::string line;
00059     long sentenceID = opts->output.start_translation_id - 1 ;
00060     while (getline(constraintFile, line)) {
00061       vector<string> vecStr = Tokenize(line, "\t");
00062 
00063       Phrase phrase(0);
00064       if (vecStr.size() == 1) {
00065         sentenceID++;
00066         phrase.CreateFromString(Output, opts->output.factor_order, vecStr[0], NULL);
00067       } else if (vecStr.size() == 2) {
00068         sentenceID = Scan<long>(vecStr[0]);
00069         phrase.CreateFromString(Output, opts->output.factor_order, vecStr[1], NULL);
00070       } else {
00071         UTIL_THROW(util::Exception, "Reference file not loaded");
00072       }
00073 
00074       if (addBeginEndWord) {
00075         phrase.InitStartEndWord();
00076       }
00077       m_constraints[sentenceID].push_back(phrase);
00078     }
00079   }
00080 }
00081 
00082 std::vector<float> ConstrainedDecoding::DefaultWeights() const
00083 {
00084   UTIL_THROW_IF2(m_numScoreComponents != 1,
00085                  "ConstrainedDecoding must only have 1 score");
00086   vector<float> ret(1, 1);
00087   return ret;
00088 }
00089 
00090 template <class H, class M>
00091 const std::vector<Phrase> *GetConstraint(const std::map<long,std::vector<Phrase> > &constraints, const H &hypo)
00092 {
00093   const M &mgr = hypo.GetManager();
00094   const InputType &input = mgr.GetSource();
00095   long id = input.GetTranslationId();
00096 
00097   map<long,std::vector<Phrase> >::const_iterator iter;
00098   iter = constraints.find(id);
00099 
00100   if (iter == constraints.end()) {
00101     UTIL_THROW(util::Exception, "Couldn't find reference " << id);
00102 
00103     return NULL;
00104   } else {
00105     return &iter->second;
00106   }
00107 }
00108 
00109 FFState* ConstrainedDecoding::EvaluateWhenApplied(
00110   const Hypothesis& hypo,
00111   const FFState* prev_state,
00112   ScoreComponentCollection* accumulator) const
00113 {
00114   const std::vector<Phrase> *ref = GetConstraint<Hypothesis, Manager>(m_constraints, hypo);
00115   assert(ref);
00116 
00117   ConstrainedDecodingState *ret = new ConstrainedDecodingState(hypo);
00118   const Phrase &outputPhrase = ret->GetPhrase();
00119 
00120   size_t searchPos = NOT_FOUND;
00121   size_t i = 0;
00122   size_t size = 0;
00123   while(searchPos == NOT_FOUND && i < ref->size()) {
00124     searchPos = (*ref)[i].Find(outputPhrase, m_maxUnknowns);
00125     size = (*ref)[i].GetSize();
00126     i++;
00127   }
00128 
00129   float score;
00130   if (hypo.IsSourceCompleted()) {
00131     
00132     bool match = (searchPos == 0) && (size == outputPhrase.GetSize());
00133     if (!m_negate) {
00134       score = match ? 0 : - ( m_soft ? 1 : std::numeric_limits<float>::infinity());
00135     } else {
00136       score = !match ? 0 : - ( m_soft ? 1 : std::numeric_limits<float>::infinity());
00137     }
00138   } else if (m_negate) {
00139     
00140     score = 0;
00141   } else {
00142     score = (searchPos != NOT_FOUND) ? 0 : - ( m_soft ? 1 : std::numeric_limits<float>::infinity());
00143   }
00144 
00145   accumulator->PlusEquals(this, score);
00146 
00147   return ret;
00148 }
00149 
00150 FFState* ConstrainedDecoding::EvaluateWhenApplied(
00151   const ChartHypothesis &hypo,
00152   int ,
00153   ScoreComponentCollection* accumulator) const
00154 {
00155   const std::vector<Phrase> *ref = GetConstraint<ChartHypothesis, ChartManager>(m_constraints, hypo);
00156   assert(ref);
00157 
00158   const ChartManager &mgr = hypo.GetManager();
00159   const Sentence &source = static_cast<const Sentence&>(mgr.GetSource());
00160 
00161   ConstrainedDecodingState *ret = new ConstrainedDecodingState(hypo);
00162   const Phrase &outputPhrase = ret->GetPhrase();
00163 
00164   size_t searchPos = NOT_FOUND;
00165   size_t i = 0;
00166   size_t size = 0;
00167   while(searchPos == NOT_FOUND && i < ref->size()) {
00168     searchPos = (*ref)[i].Find(outputPhrase, m_maxUnknowns);
00169     size = (*ref)[i].GetSize();
00170     i++;
00171   }
00172 
00173   float score;
00174   if (hypo.GetCurrSourceRange().GetStartPos() == 0 &&
00175       hypo.GetCurrSourceRange().GetEndPos() == source.GetSize() - 1) {
00176     
00177     bool match = (searchPos == 0) && (size == outputPhrase.GetSize());
00178 
00179     if (!m_negate) {
00180       score = match ? 0 : - ( m_soft ? 1 : std::numeric_limits<float>::infinity());
00181     } else {
00182       score = !match ? 0 : - ( m_soft ? 1 : std::numeric_limits<float>::infinity());
00183     }
00184   } else if (m_negate) {
00185     
00186     score = 0;
00187   } else {
00188     score = (searchPos != NOT_FOUND) ? 0 : - ( m_soft ? 1 : std::numeric_limits<float>::infinity());
00189   }
00190 
00191   accumulator->PlusEquals(this, score);
00192 
00193   return ret;
00194 }
00195 
00196 void ConstrainedDecoding::SetParameter(const std::string& key, const std::string& value)
00197 {
00198   if (key == "path") {
00199     m_paths = Tokenize(value, ",");
00200   } else if (key == "max-unknowns") {
00201     m_maxUnknowns = Scan<int>(value);
00202   } else if (key == "negate") {
00203     m_negate = Scan<bool>(value);
00204   } else if (key == "soft") {
00205     m_soft = Scan<bool>(value);
00206   } else {
00207     StatefulFeatureFunction::SetParameter(key, value);
00208   }
00209 }
00210 
00211 }
00212