00001 #pragma once
00002
00003 #include "moses/FF/FFState.h"
00004 #include "moses/ChartHypothesis.h"
00005 #include "moses/ChartManager.h"
00006
00007 namespace Moses
00008 {
00009
00010 class LanguageModelChartState : public FFState
00011 {
00012 private:
00013 float m_prefixScore;
00014 FFState* m_lmRightContext;
00015
00016 Phrase m_contextPrefix, m_contextSuffix;
00017
00018 size_t m_numTargetTerminals;
00019
00020 const ChartHypothesis &m_hypo;
00021
00026 size_t CalcPrefix(const ChartHypothesis &hypo, int featureID, Phrase &ret, size_t size) const {
00027 const TargetPhrase &target = hypo.GetCurrTargetPhrase();
00028 const AlignmentInfo::NonTermIndexMap &nonTermIndexMap =
00029 target.GetAlignNonTerm().GetNonTermIndexMap();
00030
00031
00032 for (size_t pos = 0; pos < target.GetSize(); ++pos) {
00033 const Word &word = target.GetWord(pos);
00034
00035
00036 if (word.IsNonTerminal()) {
00037 size_t nonTermInd = nonTermIndexMap[pos];
00038 const ChartHypothesis *prevHypo = hypo.GetPrevHypo(nonTermInd);
00039 size = static_cast<const LanguageModelChartState*>(prevHypo->GetFFState(featureID))->CalcPrefix(*prevHypo, featureID, ret, size);
00040 }
00041
00042 else {
00043 ret.AddWord(target.GetWord(pos));
00044 size--;
00045 }
00046
00047
00048 if (size==0)
00049 break;
00050 }
00051
00052 return size;
00053 }
00054
00060 size_t CalcSuffix(const ChartHypothesis &hypo, int featureID, Phrase &ret, size_t size) const {
00061 UTIL_THROW_IF2(m_contextPrefix.GetSize() > m_numTargetTerminals, "Error");
00062
00063
00064
00065 if (m_contextPrefix.GetSize() == m_numTargetTerminals) {
00066 size_t maxCount = std::min(m_contextPrefix.GetSize(), size);
00067 size_t pos= m_contextPrefix.GetSize() - 1;
00068
00069 for (size_t ind = 0; ind < maxCount; ++ind) {
00070 const Word &word = m_contextPrefix.GetWord(pos);
00071 ret.PrependWord(word);
00072 --pos;
00073 }
00074
00075 size -= maxCount;
00076 return size;
00077 }
00078
00079 else {
00080 const TargetPhrase& target = hypo.GetCurrTargetPhrase();
00081 const AlignmentInfo::NonTermIndexMap &nonTermIndexMap =
00082 target.GetAlignNonTerm().GetNonTermIndexMap();
00083 for (int pos = (int) target.GetSize() - 1; pos >= 0 ; --pos) {
00084 const Word &word = target.GetWord(pos);
00085
00086 if (word.IsNonTerminal()) {
00087 size_t nonTermInd = nonTermIndexMap[pos];
00088 const ChartHypothesis *prevHypo = hypo.GetPrevHypo(nonTermInd);
00089 size = static_cast<const LanguageModelChartState*>(prevHypo->GetFFState(featureID))->CalcSuffix(*prevHypo, featureID, ret, size);
00090 } else {
00091 ret.PrependWord(hypo.GetCurrTargetPhrase().GetWord(pos));
00092 size--;
00093 }
00094
00095 if (size==0)
00096 break;
00097 }
00098
00099 return size;
00100 }
00101 }
00102
00103
00104 public:
00105 LanguageModelChartState(const ChartHypothesis &hypo, int featureID, size_t order)
00106 :m_lmRightContext(NULL)
00107 ,m_contextPrefix(order - 1)
00108 ,m_contextSuffix( order - 1)
00109 ,m_hypo(hypo) {
00110 m_numTargetTerminals = hypo.GetCurrTargetPhrase().GetNumTerminals();
00111
00112 for (std::vector<const ChartHypothesis*>::const_iterator i = hypo.GetPrevHypos().begin(); i != hypo.GetPrevHypos().end(); ++i) {
00113
00114 m_numTargetTerminals += static_cast<const LanguageModelChartState*>((*i)->GetFFState(featureID))->GetNumTargetTerminals();
00115 }
00116
00117 CalcPrefix(hypo, featureID, m_contextPrefix, order - 1);
00118 CalcSuffix(hypo, featureID, m_contextSuffix, order - 1);
00119 }
00120
00121 ~LanguageModelChartState() {
00122 delete m_lmRightContext;
00123 }
00124
00125 void Set(float prefixScore, FFState *rightState) {
00126 m_prefixScore = prefixScore;
00127 m_lmRightContext = rightState;
00128 }
00129
00130 float GetPrefixScore() const {
00131 return m_prefixScore;
00132 }
00133 FFState* GetRightContext() const {
00134 return m_lmRightContext;
00135 }
00136
00137 size_t GetNumTargetTerminals() const {
00138 return m_numTargetTerminals;
00139 }
00140
00141 const Phrase &GetPrefix() const {
00142 return m_contextPrefix;
00143 }
00144 const Phrase &GetSuffix() const {
00145 return m_contextSuffix;
00146 }
00147
00148 size_t hash() const {
00149 size_t ret;
00150
00151
00152 ret = m_hypo.GetCurrSourceRange().GetStartPos() > 0;
00153 if (m_hypo.GetCurrSourceRange().GetStartPos() > 0) {
00154 size_t hash = hash_value(GetPrefix());
00155 boost::hash_combine(ret, hash);
00156 }
00157
00158
00159 size_t inputSize = m_hypo.GetManager().GetSource().GetSize();
00160 boost::hash_combine(ret, m_hypo.GetCurrSourceRange().GetEndPos() < inputSize - 1);
00161 if (m_hypo.GetCurrSourceRange().GetEndPos() < inputSize - 1) {
00162 size_t hash = m_lmRightContext->hash();
00163 boost::hash_combine(ret, hash);
00164 }
00165
00166 return ret;
00167 }
00168 virtual bool operator==(const FFState& o) const {
00169 const LanguageModelChartState &other =
00170 static_cast<const LanguageModelChartState &>( o );
00171
00172
00173 if (m_hypo.GetCurrSourceRange().GetStartPos() > 0) {
00174 bool ret = GetPrefix() == other.GetPrefix();
00175 if (ret == false)
00176 return false;
00177 }
00178
00179
00180 size_t inputSize = m_hypo.GetManager().GetSource().GetSize();
00181 if (m_hypo.GetCurrSourceRange().GetEndPos() < inputSize - 1) {
00182 bool ret = (*other.GetRightContext()) == (*m_lmRightContext);
00183 return ret;
00184 }
00185 return true;
00186 }
00187
00188 };
00189
00190 }
00191