00001 #include "HwcmScorer.h"
00002 
00003 #include <fstream>
00004 
00005 #include "ScoreStats.h"
00006 #include "Util.h"
00007 
00008 #include "util/tokenize_piece.hh"
00009 
00010 
00011 
00012 
00013 
00014 using namespace std;
00015 
00016 namespace MosesTuning
00017 {
00018 
00019 
00020 HwcmScorer::HwcmScorer(const string& config)
00021   : StatisticsBasedScorer("HWCM",config) {}
00022 
00023 HwcmScorer::~HwcmScorer() {}
00024 
00025 void HwcmScorer::setReferenceFiles(const vector<string>& referenceFiles)
00026 {
00027   
00028   if (referenceFiles.size() != 1) {
00029     throw runtime_error("HWCM only supports a single reference");
00030   }
00031   m_ref_trees.clear();
00032   m_ref_hwc.clear();
00033   ifstream in((referenceFiles[0] + ".trees").c_str());
00034   if (!in) {
00035     throw runtime_error("Unable to open " + referenceFiles[0] + ".trees");
00036   }
00037   string line;
00038   while (getline(in,line)) {
00039     line = this->preprocessSentence(line);
00040     TreePointer tree (boost::make_shared<InternalTree>(line));
00041     m_ref_trees.push_back(tree);
00042     vector<map<string, int> > hwc (kHwcmOrder);
00043     vector<string> history(kHwcmOrder);
00044     extractHeadWordChain(tree, history, hwc);
00045     m_ref_hwc.push_back(hwc);
00046     vector<int> totals(kHwcmOrder);
00047     for (size_t i = 0; i < kHwcmOrder; i++) {
00048       for (map<string, int>::const_iterator it = m_ref_hwc.back()[i].begin(); it != m_ref_hwc.back()[i].end(); it++) {
00049         totals[i] += it->second;
00050       }
00051     }
00052     m_ref_lengths.push_back(totals);
00053   }
00054   TRACE_ERR(endl);
00055 
00056 }
00057 
00058 void HwcmScorer::extractHeadWordChain(TreePointer tree, vector<string> & history, vector<map<string, int> > & hwc)
00059 {
00060 
00061   if (tree->GetLength() > 0) {
00062     string head = getHead(tree);
00063 
00064     if (head.empty()) {
00065       for (std::vector<TreePointer>::const_iterator it = tree->GetChildren().begin(); it != tree->GetChildren().end(); ++it) {
00066         extractHeadWordChain(*it, history, hwc);
00067       }
00068     } else {
00069       vector<string> new_history(kHwcmOrder);
00070       new_history[0] = head;
00071       hwc[0][head]++;
00072       for (size_t hist_idx = 0; hist_idx < kHwcmOrder-1; hist_idx++) {
00073         if (!history[hist_idx].empty()) {
00074           string chain = history[hist_idx] + " " + head;
00075           hwc[hist_idx+1][chain]++;
00076           if (hist_idx+2 < kHwcmOrder) {
00077             new_history[hist_idx+1] = chain;
00078           }
00079         }
00080       }
00081       for (std::vector<TreePointer>::const_iterator it = tree->GetChildren().begin(); it != tree->GetChildren().end(); ++it) {
00082         extractHeadWordChain(*it, new_history, hwc);
00083       }
00084     }
00085   }
00086 }
00087 
00088 string HwcmScorer::getHead(TreePointer tree)
00089 {
00090   
00091   
00092   for (std::vector<TreePointer>::const_iterator it = tree->GetChildren().begin(); it != tree->GetChildren().end(); ++it) {
00093     TreePointer child = *it;
00094 
00095     if (child->GetLength() == 1 && child->GetChildren()[0]->IsTerminal()) {
00096       return child->GetChildren()[0]->GetLabel();
00097     }
00098   }
00099   return "";
00100 
00101 }
00102 
00103 void HwcmScorer::prepareStats(size_t sid, const string& text, ScoreStats& entry)
00104 {
00105   if (sid >= m_ref_trees.size()) {
00106     stringstream msg;
00107     msg << "Sentence id (" << sid << ") not found in reference set";
00108     throw runtime_error(msg.str());
00109   }
00110 
00111   string sentence = this->preprocessSentence(text);
00112 
00113   
00114   
00115   util::TokenIter<util::MultiCharacter> it(sentence, util::MultiCharacter("|||"));
00116   ++it;
00117   if (it) {
00118     sentence = it->as_string();
00119   }
00120 
00121   TreePointer tree (boost::make_shared<InternalTree>(sentence));
00122   vector<map<string, int> > hwc_test (kHwcmOrder);
00123   vector<string> history(kHwcmOrder);
00124   extractHeadWordChain(tree, history, hwc_test);
00125 
00126   ostringstream stats;
00127   for (size_t i = 0; i < kHwcmOrder; i++) {
00128     int correct = 0;
00129     int test_total = 0;
00130     for (map<string, int>::const_iterator it = hwc_test[i].begin(); it != hwc_test[i].end(); it++) {
00131       test_total += it->second;
00132       map<string, int>::const_iterator it2 = m_ref_hwc[sid][i].find(it->first);
00133       if (it2 != m_ref_hwc[sid][i].end()) {
00134         correct += std::min(it->second, it2->second);
00135       }
00136     }
00137     stats << correct << " " << test_total << " " << m_ref_lengths[sid][i] << " " ;
00138   }
00139 
00140   string stats_str = stats.str();
00141   entry.set(stats_str);
00142 }
00143 
00144 float HwcmScorer::calculateScore(const vector<ScoreStatsType>& comps) const
00145 {
00146   float precision = 0;
00147   float recall = 0;
00148   for (size_t i = 0; i < kHwcmOrder; i++) {
00149     float matches = comps[i*3];
00150     float test_total = comps[1+(i*3)];
00151     float ref_total = comps[2+(i*3)];
00152     if (test_total > 0) {
00153       precision += matches/test_total;
00154     }
00155     if (ref_total > 0) {
00156       recall += matches/ref_total;
00157     }
00158   }
00159 
00160   precision /= (float)kHwcmOrder;
00161   recall /= (float)kHwcmOrder;
00162   return (2*precision*recall)/(precision+recall); 
00163 }
00164 
00165 }