00001 #include "HwcmScorer.h"
00002
00003 #include <fstream>
00004
00005 #include "ScoreStats.h"
00006 #include "Util.h"
00007
00008 #include "util/tokenize_piece.hh"
00009
00010
00011
00012
00013
00014 using namespace std;
00015
00016 namespace MosesTuning
00017 {
00018
00019
00020 HwcmScorer::HwcmScorer(const string& config)
00021 : StatisticsBasedScorer("HWCM",config) {}
00022
00023 HwcmScorer::~HwcmScorer() {}
00024
00025 void HwcmScorer::setReferenceFiles(const vector<string>& referenceFiles)
00026 {
00027
00028 if (referenceFiles.size() != 1) {
00029 throw runtime_error("HWCM only supports a single reference");
00030 }
00031 m_ref_trees.clear();
00032 m_ref_hwc.clear();
00033 ifstream in((referenceFiles[0] + ".trees").c_str());
00034 if (!in) {
00035 throw runtime_error("Unable to open " + referenceFiles[0] + ".trees");
00036 }
00037 string line;
00038 while (getline(in,line)) {
00039 line = this->preprocessSentence(line);
00040 TreePointer tree (boost::make_shared<InternalTree>(line));
00041 m_ref_trees.push_back(tree);
00042 vector<map<string, int> > hwc (kHwcmOrder);
00043 vector<string> history(kHwcmOrder);
00044 extractHeadWordChain(tree, history, hwc);
00045 m_ref_hwc.push_back(hwc);
00046 vector<int> totals(kHwcmOrder);
00047 for (size_t i = 0; i < kHwcmOrder; i++) {
00048 for (map<string, int>::const_iterator it = m_ref_hwc.back()[i].begin(); it != m_ref_hwc.back()[i].end(); it++) {
00049 totals[i] += it->second;
00050 }
00051 }
00052 m_ref_lengths.push_back(totals);
00053 }
00054 TRACE_ERR(endl);
00055
00056 }
00057
00058 void HwcmScorer::extractHeadWordChain(TreePointer tree, vector<string> & history, vector<map<string, int> > & hwc)
00059 {
00060
00061 if (tree->GetLength() > 0) {
00062 string head = getHead(tree);
00063
00064 if (head.empty()) {
00065 for (std::vector<TreePointer>::const_iterator it = tree->GetChildren().begin(); it != tree->GetChildren().end(); ++it) {
00066 extractHeadWordChain(*it, history, hwc);
00067 }
00068 } else {
00069 vector<string> new_history(kHwcmOrder);
00070 new_history[0] = head;
00071 hwc[0][head]++;
00072 for (size_t hist_idx = 0; hist_idx < kHwcmOrder-1; hist_idx++) {
00073 if (!history[hist_idx].empty()) {
00074 string chain = history[hist_idx] + " " + head;
00075 hwc[hist_idx+1][chain]++;
00076 if (hist_idx+2 < kHwcmOrder) {
00077 new_history[hist_idx+1] = chain;
00078 }
00079 }
00080 }
00081 for (std::vector<TreePointer>::const_iterator it = tree->GetChildren().begin(); it != tree->GetChildren().end(); ++it) {
00082 extractHeadWordChain(*it, new_history, hwc);
00083 }
00084 }
00085 }
00086 }
00087
00088 string HwcmScorer::getHead(TreePointer tree)
00089 {
00090
00091
00092 for (std::vector<TreePointer>::const_iterator it = tree->GetChildren().begin(); it != tree->GetChildren().end(); ++it) {
00093 TreePointer child = *it;
00094
00095 if (child->GetLength() == 1 && child->GetChildren()[0]->IsTerminal()) {
00096 return child->GetChildren()[0]->GetLabel();
00097 }
00098 }
00099 return "";
00100
00101 }
00102
00103 void HwcmScorer::prepareStats(size_t sid, const string& text, ScoreStats& entry)
00104 {
00105 if (sid >= m_ref_trees.size()) {
00106 stringstream msg;
00107 msg << "Sentence id (" << sid << ") not found in reference set";
00108 throw runtime_error(msg.str());
00109 }
00110
00111 string sentence = this->preprocessSentence(text);
00112
00113
00114
00115 util::TokenIter<util::MultiCharacter> it(sentence, util::MultiCharacter("|||"));
00116 ++it;
00117 if (it) {
00118 sentence = it->as_string();
00119 }
00120
00121 TreePointer tree (boost::make_shared<InternalTree>(sentence));
00122 vector<map<string, int> > hwc_test (kHwcmOrder);
00123 vector<string> history(kHwcmOrder);
00124 extractHeadWordChain(tree, history, hwc_test);
00125
00126 ostringstream stats;
00127 for (size_t i = 0; i < kHwcmOrder; i++) {
00128 int correct = 0;
00129 int test_total = 0;
00130 for (map<string, int>::const_iterator it = hwc_test[i].begin(); it != hwc_test[i].end(); it++) {
00131 test_total += it->second;
00132 map<string, int>::const_iterator it2 = m_ref_hwc[sid][i].find(it->first);
00133 if (it2 != m_ref_hwc[sid][i].end()) {
00134 correct += std::min(it->second, it2->second);
00135 }
00136 }
00137 stats << correct << " " << test_total << " " << m_ref_lengths[sid][i] << " " ;
00138 }
00139
00140 string stats_str = stats.str();
00141 entry.set(stats_str);
00142 }
00143
00144 float HwcmScorer::calculateScore(const vector<ScoreStatsType>& comps) const
00145 {
00146 float precision = 0;
00147 float recall = 0;
00148 for (size_t i = 0; i < kHwcmOrder; i++) {
00149 float matches = comps[i*3];
00150 float test_total = comps[1+(i*3)];
00151 float ref_total = comps[2+(i*3)];
00152 if (test_total > 0) {
00153 precision += matches/test_total;
00154 }
00155 if (ref_total > 0) {
00156 recall += matches/ref_total;
00157 }
00158 }
00159
00160 precision /= (float)kHwcmOrder;
00161 recall /= (float)kHwcmOrder;
00162 return (2*precision*recall)/(precision+recall);
00163 }
00164
00165 }