00001 #ifndef moses_TargetNgramFeature_h
00002 #define moses_TargetNgramFeature_h
00003
#include <algorithm>
#include <cassert>
#include <map>
#include <string>
#include <vector>

#include <boost/functional/hash.hpp>
#include <boost/unordered_set.hpp>

#include "StatefulFeatureFunction.h"
#include "moses/FF/FFState.h"
#include "moses/Word.h"
#include "moses/FactorCollection.h"
#include "moses/LM/SingleFactor.h"
#include "moses/ChartHypothesis.h"
#include "moses/ChartManager.h"
#include "util/string_stream.hh"
00016
00017 namespace Moses
00018 {
00019
00020 class TargetNgramState : public FFState
00021 {
00022 public:
00023 TargetNgramState() {}
00024
00025 TargetNgramState(const std::vector<Word> &words): m_words(words) {}
00026 const std::vector<Word> GetWords() const {
00027 return m_words;
00028 }
00029
00030 size_t hash() const;
00031 virtual bool operator==(const FFState& other) const;
00032
00033 private:
00034 std::vector<Word> m_words;
00035 };
00036
00037 class TargetNgramChartState : public FFState
00038 {
00039 private:
00040 Phrase m_contextPrefix, m_contextSuffix;
00041
00042 size_t m_numTargetTerminals;
00043
00044 size_t m_startPos, m_endPos, m_inputSize;
00045
00050 size_t CalcPrefix(const ChartHypothesis &hypo, const int featureId, Phrase &ret, size_t size) const {
00051 const TargetPhrase &target = hypo.GetCurrTargetPhrase();
00052 const AlignmentInfo::NonTermIndexMap &nonTermIndexMap =
00053 target.GetAlignNonTerm().GetNonTermIndexMap();
00054
00055
00056 for (size_t pos = 0; pos < target.GetSize(); ++pos) {
00057 const Word &word = target.GetWord(pos);
00058
00059
00060 if (word.IsNonTerminal()) {
00061 size_t nonTermInd = nonTermIndexMap[pos];
00062 const ChartHypothesis *prevHypo = hypo.GetPrevHypo(nonTermInd);
00063 size = static_cast<const TargetNgramChartState*>(prevHypo->GetFFState(featureId))->CalcPrefix(*prevHypo, featureId, ret, size);
00064
00065
00066 }
00067
00068 else {
00069 ret.AddWord(word);
00070 size--;
00071 }
00072
00073
00074 if (size==0)
00075 break;
00076 }
00077
00078 return size;
00079 }
00080
00086 size_t CalcSuffix(const ChartHypothesis &hypo, int featureId, Phrase &ret, size_t size) const {
00087 size_t prefixSize = m_contextPrefix.GetSize();
00088 assert(prefixSize <= m_numTargetTerminals);
00089
00090
00091
00092 if (prefixSize == m_numTargetTerminals) {
00093 size_t maxCount = std::min(prefixSize, size);
00094 size_t pos= prefixSize - 1;
00095
00096 for (size_t ind = 0; ind < maxCount; ++ind) {
00097 const Word &word = m_contextPrefix.GetWord(pos);
00098 ret.PrependWord(word);
00099 --pos;
00100 }
00101
00102 size -= maxCount;
00103 return size;
00104 }
00105
00106 else {
00107 const TargetPhrase targetPhrase = hypo.GetCurrTargetPhrase();
00108 const AlignmentInfo::NonTermIndexMap &nonTermIndexMap =
00109 targetPhrase.GetAlignTerm().GetNonTermIndexMap();
00110 for (int pos = (int) targetPhrase.GetSize() - 1; pos >= 0 ; --pos) {
00111 const Word &word = targetPhrase.GetWord(pos);
00112
00113 if (word.IsNonTerminal()) {
00114 size_t nonTermInd = nonTermIndexMap[pos];
00115 const ChartHypothesis *prevHypo = hypo.GetPrevHypo(nonTermInd);
00116 size = static_cast<const TargetNgramChartState*>(prevHypo->GetFFState(featureId))->CalcSuffix(*prevHypo, featureId, ret, size);
00117 } else {
00118 ret.PrependWord(word);
00119 size--;
00120 }
00121
00122 if (size==0)
00123 break;
00124 }
00125
00126 return size;
00127 }
00128 }
00129
00130 public:
00131 TargetNgramChartState(const ChartHypothesis &hypo, int featureId, size_t order)
00132 :m_contextPrefix(order - 1),
00133 m_contextSuffix(order - 1) {
00134 m_numTargetTerminals = hypo.GetCurrTargetPhrase().GetNumTerminals();
00135 const Range range = hypo.GetCurrSourceRange();
00136 m_startPos = range.GetStartPos();
00137 m_endPos = range.GetEndPos();
00138 m_inputSize = hypo.GetManager().GetSource().GetSize();
00139
00140 const std::vector<const ChartHypothesis*> prevHypos = hypo.GetPrevHypos();
00141 for (std::vector<const ChartHypothesis*>::const_iterator i = prevHypos.begin(); i != prevHypos.end(); ++i) {
00142
00143 m_numTargetTerminals += static_cast<const TargetNgramChartState*>((*i)->GetFFState(featureId))->GetNumTargetTerminals();
00144 }
00145
00146 CalcPrefix(hypo, featureId, m_contextPrefix, order - 1);
00147 CalcSuffix(hypo, featureId, m_contextSuffix, order - 1);
00148 }
00149
00150 size_t GetNumTargetTerminals() const {
00151 return m_numTargetTerminals;
00152 }
00153
00154 const Phrase &GetPrefix() const {
00155 return m_contextPrefix;
00156 }
00157 const Phrase &GetSuffix() const {
00158 return m_contextSuffix;
00159 }
00160
00161 size_t hash() const {
00162
00163 size_t ret;
00164
00165 ret = m_startPos;
00166 boost::hash_combine(ret, m_endPos);
00167 boost::hash_combine(ret, m_inputSize);
00168
00169
00170 if (m_startPos > 0) {
00171 boost::hash_combine(ret, hash_value(GetPrefix()));
00172 }
00173
00174 if (m_endPos < m_inputSize - 1) {
00175 boost::hash_combine(ret, hash_value(GetSuffix()));
00176 }
00177
00178 return ret;
00179 }
00180 virtual bool operator==(const FFState& o) const {
00181 const TargetNgramChartState &other =
00182 static_cast<const TargetNgramChartState &>( o );
00183
00184
00185 if (m_startPos > 0) {
00186 if (GetPrefix() != other.GetPrefix())
00187 return false;
00188 }
00189
00190 if (m_endPos < m_inputSize - 1) {
00191 if (GetSuffix() != other.GetSuffix())
00192 return false;
00193 }
00194 return true;
00195 }
00196
00197 };
00198
00201 class TargetNgramFeature : public StatefulFeatureFunction
00202 {
00203 public:
00204 TargetNgramFeature(const std::string &line);
00205
00206 void Load(AllOptions::ptr const& opts);
00207
00208 bool IsUseable(const FactorMask &mask) const;
00209
00210 virtual const FFState* EmptyHypothesisState(const InputType &input) const;
00211
00212 virtual FFState* EvaluateWhenApplied(const Hypothesis& cur_hypo, const FFState* prev_state,
00213 ScoreComponentCollection* accumulator) const;
00214
00215 virtual FFState* EvaluateWhenApplied(const ChartHypothesis& cur_hypo, int featureId,
00216 ScoreComponentCollection* accumulator) const;
00217
00218 void SetParameter(const std::string& key, const std::string& value);
00219
00220 private:
00221 FactorType m_factorType;
00222 Word m_bos;
00223 boost::unordered_set<std::string> m_vocab;
00224 size_t m_n;
00225 bool m_lower_ngrams;
00226 std::string m_file;
00227
00228 std::string m_baseName;
00229
00230 void appendNgram(const Word& word, bool& skip, util::StringStream& ngram) const;
00231 void MakePrefixNgrams(std::vector<const Word*> &contextFactor, ScoreComponentCollection* accumulator,
00232 size_t numberOfStartPos = 1, size_t offset = 0) const;
00233 void MakeSuffixNgrams(std::vector<const Word*> &contextFactor, ScoreComponentCollection* accumulator,
00234 size_t numberOfEndPos = 1, size_t offset = 0) const;
00235 };
00236
00237 }
00238
00239 #endif // moses_TargetNgramFeature_h