00001 
00002 #include "DistortionScoreProducer.h"
00003 #include "FFState.h"
00004 #include "moses/InputPath.h"
00005 #include "moses/Range.h"
00006 #include "moses/StaticData.h"
00007 #include "moses/Hypothesis.h"
00008 #include "moses/Manager.h"
00009 #include "moses/FactorCollection.h"
00010 #include <cmath>
00011 
00012 using namespace std;
00013 
00014 namespace Moses
00015 {
00016 struct DistortionState : public FFState {
00017   Range range;
00018   int first_gap;
00019   bool inSubordinateConjunction;
00020   DistortionState(const Range& wr, int fg, bool subord=false) : range(wr), first_gap(fg), inSubordinateConjunction(subord) {}
00021 
00022   size_t hash() const {
00023     return range.GetEndPos();
00024   }
00025   virtual bool operator==(const FFState& other) const {
00026     const DistortionState& o =
00027       static_cast<const DistortionState&>(other);
00028     return ( (range.GetEndPos() == o.range.GetEndPos()) && (inSubordinateConjunction == o.inSubordinateConjunction) );
00029   }
00030 
00031 };
00032 
00033 std::vector<const DistortionScoreProducer*> DistortionScoreProducer::s_staticColl;
00034 
00035 DistortionScoreProducer::DistortionScoreProducer(const std::string &line)
00036   : StatefulFeatureFunction(1, line)
00037   , m_useSparse(false)
00038   , m_sparseDistance(false)
00039   , m_sparseSubordinate(false)
00040 {
00041   s_staticColl.push_back(this);
00042   ReadParameters();
00043 }
00044 
00045 void DistortionScoreProducer::SetParameter(const std::string& key, const std::string& value)
00046 {
00047   if (key == "sparse") {
00048     m_useSparse = Scan<bool>(value);
00049   } else if (key == "sparse-distance") {
00050     m_sparseDistance = Scan<bool>(value);
00051   } else if (key == "sparse-input-factor") {
00052     m_sparseFactorTypeSource = Scan<FactorType>(value);
00053   } else if (key == "sparse-output-factor") {
00054     m_sparseFactorTypeTarget = Scan<FactorType>(value);
00055   } else if (key == "sparse-subordinate") {
00056     std::string subordinateConjunctionTag = Scan<std::string>(value);
00057     FactorCollection &factorCollection = FactorCollection::Instance();
00058     m_subordinateConjunctionTagFactor = factorCollection.AddFactor(subordinateConjunctionTag,false);
00059     m_sparseSubordinate = true;
00060   } else if (key == "sparse-subordinate-output-factor") {
00061     m_sparseFactorTypeTargetSubordinate = Scan<FactorType>(value);
00062   } else {
00063     StatefulFeatureFunction::SetParameter(key, value);
00064   }
00065 }
00066 
00067 const FFState* DistortionScoreProducer::EmptyHypothesisState(const InputType &input) const
00068 {
00069   
00070   size_t start = NOT_FOUND;
00071   size_t end = NOT_FOUND;
00072   if (input.m_frontSpanCoveredLength > 0) {
00073     
00074     start = 0;
00075     end = input.m_frontSpanCoveredLength -1;
00076   }
00077   return new DistortionState(
00078            Range(start, end),
00079            NOT_FOUND);
00080 }
00081 
00082 float
00083 DistortionScoreProducer::
00084 CalculateDistortionScore(const Hypothesis& hypo,
00085                          const Range &prev, const Range &curr, const int FirstGap)
00086 {
00087   
00088   if(!hypo.GetManager().options()->reordering.use_early_distortion_cost) {
00089     return - (float) hypo.GetInput().ComputeDistortionDistance(prev, curr);
00090   } 
00091 
00092   
00093 
00094 
00095 
00096 
00097 
00098 
00099   int prefixEndPos = (int)FirstGap-1;
00100   if((int)FirstGap==-1)
00101     prefixEndPos = -1;
00102 
00103   
00104   if ((int) curr.GetStartPos() == prefixEndPos+1) {
00105     IFVERBOSE(4) std::cerr<< "MQ07disto:case1" << std::endl;
00106     return 0;
00107   }
00108 
00109   
00110   if ((int) curr.GetEndPos() < (int) prev.GetEndPos()) {
00111     IFVERBOSE(4) std::cerr<< "MQ07disto:case2" << std::endl;
00112     return (float) -2*(int)curr.GetNumWordsCovered();
00113   }
00114 
00115   
00116   if ((int) prev.GetEndPos() <= prefixEndPos) {
00117     IFVERBOSE(4) std::cerr<< "MQ07disto:case3" << std::endl;
00118     int z = (int)curr.GetStartPos()-prefixEndPos - 1;
00119     return (float) -2*(z + (int)curr.GetNumWordsCovered());
00120   }
00121 
00122   
00123   IFVERBOSE(4) std::cerr<< "MQ07disto:case4" << std::endl;
00124   return (float) -2*((int)curr.GetNumWordsBetween(prev) + (int)curr.GetNumWordsCovered());
00125 
00126 }
00127 
00128 
00129 FFState* DistortionScoreProducer::EvaluateWhenApplied(
00130   const Hypothesis& hypo,
00131   const FFState* prev_state,
00132   ScoreComponentCollection* out) const
00133 {
00134   const DistortionState* prev = static_cast<const DistortionState*>(prev_state);
00135   bool subordinateConjunction = prev->inSubordinateConjunction;
00136 
00137   if (m_useSparse) {
00138     int jumpFromPos = prev->range.GetEndPos()+1;
00139     int jumpToPos = hypo.GetCurrSourceWordsRange().GetStartPos();
00140     size_t distance = std::abs( jumpFromPos - jumpToPos );
00141 
00142     const Sentence& sentence = static_cast<const Sentence&>(hypo.GetInput());
00143 
00144     StringPiece jumpFromSourceFactorPrev;
00145     StringPiece jumpFromSourceFactor;
00146     StringPiece jumpToSourceFactor;
00147     if (jumpFromPos < (int)sentence.GetSize()) {
00148       jumpFromSourceFactor = sentence.GetWord(jumpFromPos).GetFactor(m_sparseFactorTypeSource)->GetString();
00149     } else {
00150       jumpFromSourceFactor = "</s>";
00151     }
00152     if (jumpFromPos > 0) {
00153       jumpFromSourceFactorPrev = sentence.GetWord(jumpFromPos-1).GetFactor(m_sparseFactorTypeSource)->GetString();
00154     } else {
00155       jumpFromSourceFactorPrev = "<s>";
00156     }
00157     jumpToSourceFactor = sentence.GetWord(jumpToPos).GetFactor(m_sparseFactorTypeSource)->GetString();
00158 
00159     const TargetPhrase& currTargetPhrase = hypo.GetCurrTargetPhrase();
00160     StringPiece jumpToTargetFactor = currTargetPhrase.GetWord(0).GetFactor(m_sparseFactorTypeTarget)->GetString();
00161 
00162     util::StringStream featureName;
00163 
00164     
00165     featureName = util::StringStream();
00166     featureName << m_description << "_";
00167     if ( jumpToPos > jumpFromPos ) {
00168       featureName << "R";
00169     } else if ( jumpToPos < jumpFromPos ) {
00170       featureName << "L";
00171     } else {
00172       featureName << "M";
00173     }
00174     if (m_sparseDistance) {
00175       featureName << distance;
00176     }
00177     featureName << "_SFS_" << jumpFromSourceFactor;
00178     if (m_sparseSubordinate && subordinateConjunction) {
00179       featureName << "_SUBORD";
00180     }
00181     out->SparsePlusEquals(featureName.str(), 1);
00182 
00183     
00184     featureName = util::StringStream();
00185     featureName << m_description << "_";
00186     if ( jumpToPos > jumpFromPos ) {
00187       featureName << "R";
00188     } else if ( jumpToPos < jumpFromPos ) {
00189       featureName << "L";
00190     } else {
00191       featureName << "M";
00192     }
00193     if (m_sparseDistance) {
00194       featureName << distance;
00195     }
00196     featureName << "_SFP_" << jumpFromSourceFactorPrev;
00197     if (m_sparseSubordinate && subordinateConjunction) {
00198       featureName << "_SUBORD";
00199     }
00200     out->SparsePlusEquals(featureName.str(), 1);
00201 
00202     
00203     featureName = util::StringStream();
00204     featureName << m_description << "_";
00205     if ( jumpToPos > jumpFromPos ) {
00206       featureName << "R";
00207     } else if ( jumpToPos < jumpFromPos ) {
00208       featureName << "L";
00209     } else {
00210       featureName << "M";
00211     }
00212     if (m_sparseDistance) {
00213       featureName << distance;
00214     }
00215     featureName << "_SFE_" << jumpToSourceFactor;
00216     if (m_sparseSubordinate && subordinateConjunction) {
00217       featureName << "_SUBORD";
00218     }
00219     out->SparsePlusEquals(featureName.str(), 1);
00220 
00221     
00222     featureName = util::StringStream();
00223     featureName << m_description << "_";
00224     if ( jumpToPos > jumpFromPos ) {
00225       featureName << "R";
00226     } else if ( jumpToPos < jumpFromPos ) {
00227       featureName << "L";
00228     } else {
00229       featureName << "M";
00230     }
00231     if (m_sparseDistance) {
00232       featureName << distance;
00233     }
00234     featureName << "_TFE_" << jumpToTargetFactor;
00235     if (m_sparseSubordinate && subordinateConjunction) {
00236       featureName << "_SUBORD";
00237     }
00238     out->SparsePlusEquals(featureName.str(), 1);
00239 
00240     
00241     featureName = util::StringStream();
00242     featureName << m_description << "_";
00243     if ( jumpToPos > jumpFromPos ) {
00244       featureName << "R";
00245     } else if ( jumpToPos < jumpFromPos ) {
00246       featureName << "L";
00247     } else {
00248       featureName << "M";
00249     }
00250     if (m_sparseDistance) {
00251       featureName << distance;
00252     }
00253     size_t relativeSourceSentencePosBin = std::floor( 5 * (float)jumpFromPos / (sentence.GetSize()+1) );
00254     featureName << "_P_" << relativeSourceSentencePosBin;
00255     if (m_sparseSubordinate && subordinateConjunction) {
00256       featureName << "_SUBORD";
00257     }
00258     out->SparsePlusEquals(featureName.str(), 1);
00259 
00260     
00261     featureName = util::StringStream();
00262     featureName << m_description << "_";
00263     if ( jumpToPos > jumpFromPos ) {
00264       featureName << "R";
00265     } else if ( jumpToPos < jumpFromPos ) {
00266       featureName << "L";
00267     } else {
00268       featureName << "M";
00269     }
00270     if (m_sparseDistance) {
00271       featureName << distance;
00272     }
00273     size_t sourceSentenceLengthBin = 3;
00274     if (sentence.GetSize() < 15) {
00275       sourceSentenceLengthBin = 0;
00276     } else if (sentence.GetSize() < 23) {
00277       sourceSentenceLengthBin = 1;
00278     } else if (sentence.GetSize() < 33) {
00279       sourceSentenceLengthBin = 2;
00280     }
00281     featureName << "_SL_" << sourceSentenceLengthBin;
00282     if (m_sparseSubordinate && subordinateConjunction) {
00283       featureName << "_SUBORD";
00284     }
00285     out->SparsePlusEquals(featureName.str(), 1);
00286 
00287     if (m_sparseSubordinate) {
00288       for (size_t posT=0; posT<currTargetPhrase.GetSize(); ++posT) {
00289         const Word &wordT = currTargetPhrase.GetWord(posT);
00290         if (wordT[m_sparseFactorTypeTargetSubordinate] == m_subordinateConjunctionTagFactor) {
00291           subordinateConjunction = true;
00292         } else if (wordT[m_sparseFactorTypeTargetSubordinate]->GetString()[0] == 'V') {
00293           subordinateConjunction = false;
00294         }
00295       };
00296     }
00297   }
00298 
00299   const float distortionScore = CalculateDistortionScore(
00300                                   hypo,
00301                                   prev->range,
00302                                   hypo.GetCurrSourceWordsRange(),
00303                                   prev->first_gap);
00304   out->PlusEquals(this, distortionScore);
00305 
00306   DistortionState* state = new DistortionState(
00307     hypo.GetCurrSourceWordsRange(),
00308     hypo.GetWordsBitmap().GetFirstGapPos(),
00309     subordinateConjunction);
00310 
00311   return state;
00312 }
00313 
00314 
00315 }
00316