00001 #pragma once
00002
#include <limits>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
00007
00008 #include <boost/unordered_map.hpp>
00009 #include <boost/functional/hash.hpp>
00010
00011 #include "moses/FF/StatefulFeatureFunction.h"
00012 #include "moses/PP/CountsPhraseProperty.h"
00013 #include "moses/TranslationOptionList.h"
00014 #include "moses/TranslationOption.h"
00015 #include "moses/Util.h"
00016 #include "moses/TypeDef.h"
00017 #include "moses/StaticData.h"
00018 #include "moses/Phrase.h"
00019 #include "moses/AlignmentInfo.h"
00020 #include "moses/Word.h"
00021 #include "moses/FactorCollection.h"
00022
00023 #include "Normalizer.h"
00024 #include "Classifier.h"
00025 #include "VWFeatureBase.h"
00026 #include "TabbedSentence.h"
00027 #include "ThreadLocalByFeatureStorage.h"
00028 #include "TrainingLoss.h"
00029 #include "VWTargetSentence.h"
00030
00031
00032
00033
00034
00035
00036
00037 namespace Moses
00038 {
00039
00040
// Label string handed out by VW::MakeTargetLabel() for every target phrase;
// as its name says, it is a dummy value rather than a real class label.
const std::string VW_DUMMY_LABEL("1111");
00042
00043
00044 typedef ThreadLocalByFeatureStorage<Discriminative::Classifier, Discriminative::ClassifierFactory &> TLSClassifier;
00045
00046
00047 typedef ThreadLocalByFeatureStorage<VWTargetSentence> TLSTargetSentence;
00048
00049
00050 typedef boost::unordered_map<size_t, Discriminative::FeatureVector> FeatureVectorMap;
00051
00052
00053 typedef ThreadLocalByFeatureStorage<FeatureVectorMap> TLSFeatureVectorMap;
00054
00055
00056 typedef boost::unordered_map<size_t, float> FloatHashMap;
00057
00058
00059 typedef ThreadLocalByFeatureStorage<FloatHashMap> TLSFloatHashMap;
00060
00061
00062 typedef ThreadLocalByFeatureStorage<boost::unordered_map<size_t, FloatHashMap> > TLSStateExtensions;
00063
00064
00065
00066
00067 class VW : public StatefulFeatureFunction, public TLSTargetSentence
00068 {
00069 public:
00070 VW(const std::string &line);
00071
00072 virtual ~VW();
00073
00074 bool IsUseable(const FactorMask &mask) const {
00075 return true;
00076 }
00077
00078 void EvaluateInIsolation(const Phrase &source
00079 , const TargetPhrase &targetPhrase
00080 , ScoreComponentCollection &scoreBreakdown
00081 , ScoreComponentCollection &estimatedFutureScore) const {
00082 }
00083
00084 void EvaluateWithSourceContext(const InputType &input
00085 , const InputPath &inputPath
00086 , const TargetPhrase &targetPhrase
00087 , const StackVec *stackVec
00088 , ScoreComponentCollection &scoreBreakdown
00089 , ScoreComponentCollection *estimatedFutureScore = NULL) const {
00090 }
00091
00092
00093
00094
00095
00096
00097
00098
00099
00100
00101
00102
00103
00104
00105
00106
00107 virtual void EvaluateTranslationOptionListWithSourceContext(const InputType &input
00108 , const TranslationOptionList &translationOptionList) const;
00109
00110
00111
00112
00113
00114
00115
00116
00117 virtual FFState* EvaluateWhenApplied(
00118 const Hypothesis& curHypo,
00119 const FFState* prevState,
00120 ScoreComponentCollection* accumulator) const;
00121
00122 virtual FFState* EvaluateWhenApplied(
00123 const ChartHypothesis&,
00124 int,
00125 ScoreComponentCollection* accumulator) const {
00126 throw new std::logic_error("hiearchical/syntax not supported");
00127 }
00128
00129
00130 const FFState* EmptyHypothesisState(const InputType &input) const;
00131
00132 void SetParameter(const std::string& key, const std::string& value);
00133
00134
00135
00136 virtual void InitializeForInput(ttasksptr const& ttask);
00137
00138 private:
00139 inline std::string MakeTargetLabel(const TargetPhrase &targetPhrase) const {
00140 return VW_DUMMY_LABEL;
00141 }
00142
00143 inline size_t MakeCacheKey(const FFState *prevState, size_t spanStart, size_t spanEnd) const {
00144 size_t key = 0;
00145 boost::hash_combine(key, prevState);
00146 boost::hash_combine(key, spanStart);
00147 boost::hash_combine(key, spanEnd);
00148 return key;
00149 }
00150
00151
00152
00153
00154 const AlignmentInfo *TransformAlignmentInfo(const Hypothesis &curHypo, size_t contextSize) const;
00155
00156
00157
00158 AlignmentInfo TransformAlignmentInfo(const AlignmentInfo &alignInfo, size_t contextSize, int currentStart) const;
00159
00160
00161
00162
00163 std::pair<bool, int> IsCorrectTranslationOption(const TranslationOption &topt) const;
00164
00165
00166
00167 std::vector<bool> LeaveOneOut(const TranslationOptionList &topts, const std::vector<bool> &correct) const;
00168
00169 bool m_train;
00170 std::string m_modelPath;
00171 std::string m_vwOptions;
00172
00173
00174 Word m_sentenceStartWord;
00175
00176
00177 TrainingLoss *m_trainingLoss = NULL;
00178
00179
00180 std::string m_leaveOneOut;
00181
00182
00183 Discriminative::Normalizer *m_normalizer = NULL;
00184
00185
00186 TLSClassifier *m_tlsClassifier;
00187
00188
00189 TLSFloatHashMap *m_tlsFutureScores;
00190 TLSStateExtensions *m_tlsComputedStateExtensions;
00191 TLSFeatureVectorMap *m_tlsTranslationOptionFeatures, *m_tlsTargetContextFeatures;
00192 };
00193
00194 }