00001 #pragma once
00002
#include <set>
#include <cmath>
#include <string>
#include <vector>

#include "moses/Util.h"
#include "moses/StaticData.h"
#include "moses/Phrase.h"
00010
00011 namespace Moses
00012 {
00013
00017 class TrainingLoss
00018 {
00019 public:
00020 virtual float operator()(const TargetPhrase &candidate, const TargetPhrase &correct, bool isCorrect) const = 0;
00021 };
00022
00026 class TrainingLossBasic : public TrainingLoss
00027 {
00028 public:
00029 virtual float operator()(const TargetPhrase &candidate, const TargetPhrase &correct, bool isCorrect) const {
00030 return isCorrect ? 0.0 : 1.0;
00031 }
00032 };
00033
00037 class TrainingLossBLEU : public TrainingLoss
00038 {
00039 public:
00040 virtual float operator()(const TargetPhrase &candidate, const TargetPhrase &correct, bool isCorrect) const {
00041 std::multiset<std::string> refNgrams;
00042 float precision = 1.0;
00043
00044 for (size_t size = 1; size <= BLEU_N; size++) {
00045 for (int pos = 0; pos <= (int)correct.GetSize() - (int)size; pos++) {
00046 refNgrams.insert(MakeNGram(correct, pos, pos + size));
00047 }
00048
00049 int confirmed = 1;
00050 int total = 1;
00051 for (int pos = 0; pos <= (int)candidate.GetSize() - (int)size; pos++) {
00052 total++;
00053 std::string ngram = MakeNGram(candidate, pos, pos + size);
00054 std::multiset<std::string>::iterator it;
00055 if ((it = refNgrams.find(ngram)) != refNgrams.end()) {
00056 confirmed++;
00057 refNgrams.erase(it);
00058 }
00059 }
00060 precision *= (float)confirmed / total;
00061 }
00062
00063 int c = candidate.GetSize();
00064 int r = correct.GetSize();
00065
00066 float brevityPenalty = c < r ? exp((float)(1.0 - r) / c) : 1.0;
00067
00068 return 1.0 - brevityPenalty * pow(precision, (float)1.0 / BLEU_N);
00069 }
00070
00071 private:
00072 std::string MakeNGram(const TargetPhrase &phrase, size_t start, size_t end) const {
00073 std::vector<std::string> words;
00074 while (start != end) {
00075 words.push_back(phrase.GetWord(start).GetString(StaticData::Instance().options()->output.factor_order, false));
00076 start++;
00077 }
00078 return Join(" ", words);
00079 }
00080
00081 static const size_t BLEU_N = 2;
00082 };
00083
00084 }
00085