00001
00002 #include "DistortionScoreProducer.h"
00003 #include "FFState.h"
00004 #include "moses/InputPath.h"
00005 #include "moses/Range.h"
00006 #include "moses/StaticData.h"
00007 #include "moses/Hypothesis.h"
00008 #include "moses/Manager.h"
00009 #include "moses/FactorCollection.h"
00010 #include <cmath>
00011
00012 using namespace std;
00013
00014 namespace Moses
00015 {
00016 struct DistortionState : public FFState {
00017 Range range;
00018 int first_gap;
00019 bool inSubordinateConjunction;
00020 DistortionState(const Range& wr, int fg, bool subord=false) : range(wr), first_gap(fg), inSubordinateConjunction(subord) {}
00021
00022 size_t hash() const {
00023 return range.GetEndPos();
00024 }
00025 virtual bool operator==(const FFState& other) const {
00026 const DistortionState& o =
00027 static_cast<const DistortionState&>(other);
00028 return ( (range.GetEndPos() == o.range.GetEndPos()) && (inSubordinateConjunction == o.inSubordinateConjunction) );
00029 }
00030
00031 };
00032
00033 std::vector<const DistortionScoreProducer*> DistortionScoreProducer::s_staticColl;
00034
00035 DistortionScoreProducer::DistortionScoreProducer(const std::string &line)
00036 : StatefulFeatureFunction(1, line)
00037 , m_useSparse(false)
00038 , m_sparseDistance(false)
00039 , m_sparseSubordinate(false)
00040 {
00041 s_staticColl.push_back(this);
00042 ReadParameters();
00043 }
00044
00045 void DistortionScoreProducer::SetParameter(const std::string& key, const std::string& value)
00046 {
00047 if (key == "sparse") {
00048 m_useSparse = Scan<bool>(value);
00049 } else if (key == "sparse-distance") {
00050 m_sparseDistance = Scan<bool>(value);
00051 } else if (key == "sparse-input-factor") {
00052 m_sparseFactorTypeSource = Scan<FactorType>(value);
00053 } else if (key == "sparse-output-factor") {
00054 m_sparseFactorTypeTarget = Scan<FactorType>(value);
00055 } else if (key == "sparse-subordinate") {
00056 std::string subordinateConjunctionTag = Scan<std::string>(value);
00057 FactorCollection &factorCollection = FactorCollection::Instance();
00058 m_subordinateConjunctionTagFactor = factorCollection.AddFactor(subordinateConjunctionTag,false);
00059 m_sparseSubordinate = true;
00060 } else if (key == "sparse-subordinate-output-factor") {
00061 m_sparseFactorTypeTargetSubordinate = Scan<FactorType>(value);
00062 } else {
00063 StatefulFeatureFunction::SetParameter(key, value);
00064 }
00065 }
00066
00067 const FFState* DistortionScoreProducer::EmptyHypothesisState(const InputType &input) const
00068 {
00069
00070 size_t start = NOT_FOUND;
00071 size_t end = NOT_FOUND;
00072 if (input.m_frontSpanCoveredLength > 0) {
00073
00074 start = 0;
00075 end = input.m_frontSpanCoveredLength -1;
00076 }
00077 return new DistortionState(
00078 Range(start, end),
00079 NOT_FOUND);
00080 }
00081
00082 float
00083 DistortionScoreProducer::
00084 CalculateDistortionScore(const Hypothesis& hypo,
00085 const Range &prev, const Range &curr, const int FirstGap)
00086 {
00087
00088 if(!hypo.GetManager().options()->reordering.use_early_distortion_cost) {
00089 return - (float) hypo.GetInput().ComputeDistortionDistance(prev, curr);
00090 }
00091
00092
00093
00094
00095
00096
00097
00098
00099 int prefixEndPos = (int)FirstGap-1;
00100 if((int)FirstGap==-1)
00101 prefixEndPos = -1;
00102
00103
00104 if ((int) curr.GetStartPos() == prefixEndPos+1) {
00105 IFVERBOSE(4) std::cerr<< "MQ07disto:case1" << std::endl;
00106 return 0;
00107 }
00108
00109
00110 if ((int) curr.GetEndPos() < (int) prev.GetEndPos()) {
00111 IFVERBOSE(4) std::cerr<< "MQ07disto:case2" << std::endl;
00112 return (float) -2*(int)curr.GetNumWordsCovered();
00113 }
00114
00115
00116 if ((int) prev.GetEndPos() <= prefixEndPos) {
00117 IFVERBOSE(4) std::cerr<< "MQ07disto:case3" << std::endl;
00118 int z = (int)curr.GetStartPos()-prefixEndPos - 1;
00119 return (float) -2*(z + (int)curr.GetNumWordsCovered());
00120 }
00121
00122
00123 IFVERBOSE(4) std::cerr<< "MQ07disto:case4" << std::endl;
00124 return (float) -2*((int)curr.GetNumWordsBetween(prev) + (int)curr.GetNumWordsCovered());
00125
00126 }
00127
00128
00129 FFState* DistortionScoreProducer::EvaluateWhenApplied(
00130 const Hypothesis& hypo,
00131 const FFState* prev_state,
00132 ScoreComponentCollection* out) const
00133 {
00134 const DistortionState* prev = static_cast<const DistortionState*>(prev_state);
00135 bool subordinateConjunction = prev->inSubordinateConjunction;
00136
00137 if (m_useSparse) {
00138 int jumpFromPos = prev->range.GetEndPos()+1;
00139 int jumpToPos = hypo.GetCurrSourceWordsRange().GetStartPos();
00140 size_t distance = std::abs( jumpFromPos - jumpToPos );
00141
00142 const Sentence& sentence = static_cast<const Sentence&>(hypo.GetInput());
00143
00144 StringPiece jumpFromSourceFactorPrev;
00145 StringPiece jumpFromSourceFactor;
00146 StringPiece jumpToSourceFactor;
00147 if (jumpFromPos < (int)sentence.GetSize()) {
00148 jumpFromSourceFactor = sentence.GetWord(jumpFromPos).GetFactor(m_sparseFactorTypeSource)->GetString();
00149 } else {
00150 jumpFromSourceFactor = "</s>";
00151 }
00152 if (jumpFromPos > 0) {
00153 jumpFromSourceFactorPrev = sentence.GetWord(jumpFromPos-1).GetFactor(m_sparseFactorTypeSource)->GetString();
00154 } else {
00155 jumpFromSourceFactorPrev = "<s>";
00156 }
00157 jumpToSourceFactor = sentence.GetWord(jumpToPos).GetFactor(m_sparseFactorTypeSource)->GetString();
00158
00159 const TargetPhrase& currTargetPhrase = hypo.GetCurrTargetPhrase();
00160 StringPiece jumpToTargetFactor = currTargetPhrase.GetWord(0).GetFactor(m_sparseFactorTypeTarget)->GetString();
00161
00162 util::StringStream featureName;
00163
00164
00165 featureName = util::StringStream();
00166 featureName << m_description << "_";
00167 if ( jumpToPos > jumpFromPos ) {
00168 featureName << "R";
00169 } else if ( jumpToPos < jumpFromPos ) {
00170 featureName << "L";
00171 } else {
00172 featureName << "M";
00173 }
00174 if (m_sparseDistance) {
00175 featureName << distance;
00176 }
00177 featureName << "_SFS_" << jumpFromSourceFactor;
00178 if (m_sparseSubordinate && subordinateConjunction) {
00179 featureName << "_SUBORD";
00180 }
00181 out->SparsePlusEquals(featureName.str(), 1);
00182
00183
00184 featureName = util::StringStream();
00185 featureName << m_description << "_";
00186 if ( jumpToPos > jumpFromPos ) {
00187 featureName << "R";
00188 } else if ( jumpToPos < jumpFromPos ) {
00189 featureName << "L";
00190 } else {
00191 featureName << "M";
00192 }
00193 if (m_sparseDistance) {
00194 featureName << distance;
00195 }
00196 featureName << "_SFP_" << jumpFromSourceFactorPrev;
00197 if (m_sparseSubordinate && subordinateConjunction) {
00198 featureName << "_SUBORD";
00199 }
00200 out->SparsePlusEquals(featureName.str(), 1);
00201
00202
00203 featureName = util::StringStream();
00204 featureName << m_description << "_";
00205 if ( jumpToPos > jumpFromPos ) {
00206 featureName << "R";
00207 } else if ( jumpToPos < jumpFromPos ) {
00208 featureName << "L";
00209 } else {
00210 featureName << "M";
00211 }
00212 if (m_sparseDistance) {
00213 featureName << distance;
00214 }
00215 featureName << "_SFE_" << jumpToSourceFactor;
00216 if (m_sparseSubordinate && subordinateConjunction) {
00217 featureName << "_SUBORD";
00218 }
00219 out->SparsePlusEquals(featureName.str(), 1);
00220
00221
00222 featureName = util::StringStream();
00223 featureName << m_description << "_";
00224 if ( jumpToPos > jumpFromPos ) {
00225 featureName << "R";
00226 } else if ( jumpToPos < jumpFromPos ) {
00227 featureName << "L";
00228 } else {
00229 featureName << "M";
00230 }
00231 if (m_sparseDistance) {
00232 featureName << distance;
00233 }
00234 featureName << "_TFE_" << jumpToTargetFactor;
00235 if (m_sparseSubordinate && subordinateConjunction) {
00236 featureName << "_SUBORD";
00237 }
00238 out->SparsePlusEquals(featureName.str(), 1);
00239
00240
00241 featureName = util::StringStream();
00242 featureName << m_description << "_";
00243 if ( jumpToPos > jumpFromPos ) {
00244 featureName << "R";
00245 } else if ( jumpToPos < jumpFromPos ) {
00246 featureName << "L";
00247 } else {
00248 featureName << "M";
00249 }
00250 if (m_sparseDistance) {
00251 featureName << distance;
00252 }
00253 size_t relativeSourceSentencePosBin = std::floor( 5 * (float)jumpFromPos / (sentence.GetSize()+1) );
00254 featureName << "_P_" << relativeSourceSentencePosBin;
00255 if (m_sparseSubordinate && subordinateConjunction) {
00256 featureName << "_SUBORD";
00257 }
00258 out->SparsePlusEquals(featureName.str(), 1);
00259
00260
00261 featureName = util::StringStream();
00262 featureName << m_description << "_";
00263 if ( jumpToPos > jumpFromPos ) {
00264 featureName << "R";
00265 } else if ( jumpToPos < jumpFromPos ) {
00266 featureName << "L";
00267 } else {
00268 featureName << "M";
00269 }
00270 if (m_sparseDistance) {
00271 featureName << distance;
00272 }
00273 size_t sourceSentenceLengthBin = 3;
00274 if (sentence.GetSize() < 15) {
00275 sourceSentenceLengthBin = 0;
00276 } else if (sentence.GetSize() < 23) {
00277 sourceSentenceLengthBin = 1;
00278 } else if (sentence.GetSize() < 33) {
00279 sourceSentenceLengthBin = 2;
00280 }
00281 featureName << "_SL_" << sourceSentenceLengthBin;
00282 if (m_sparseSubordinate && subordinateConjunction) {
00283 featureName << "_SUBORD";
00284 }
00285 out->SparsePlusEquals(featureName.str(), 1);
00286
00287 if (m_sparseSubordinate) {
00288 for (size_t posT=0; posT<currTargetPhrase.GetSize(); ++posT) {
00289 const Word &wordT = currTargetPhrase.GetWord(posT);
00290 if (wordT[m_sparseFactorTypeTargetSubordinate] == m_subordinateConjunctionTagFactor) {
00291 subordinateConjunction = true;
00292 } else if (wordT[m_sparseFactorTypeTargetSubordinate]->GetString()[0] == 'V') {
00293 subordinateConjunction = false;
00294 }
00295 };
00296 }
00297 }
00298
00299 const float distortionScore = CalculateDistortionScore(
00300 hypo,
00301 prev->range,
00302 hypo.GetCurrSourceWordsRange(),
00303 prev->first_gap);
00304 out->PlusEquals(this, distortionScore);
00305
00306 DistortionState* state = new DistortionState(
00307 hypo.GetCurrSourceWordsRange(),
00308 hypo.GetWordsBitmap().GetFirstGapPos(),
00309 subordinateConjunction);
00310
00311 return state;
00312 }
00313
00314
00315 }
00316