00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020 #include "tree_scorer.h"
00021
00022 #include <cassert>
00023 #include <sstream>
00024
00025 namespace MosesTraining
00026 {
00027 namespace Syntax
00028 {
00029 namespace PCFG
00030 {
00031
00032 TreeScorer::TreeScorer(const Pcfg &pcfg, const Vocabulary &non_term_vocab)
00033 : pcfg_(pcfg)
00034 , non_term_vocab_(non_term_vocab)
00035 {
00036 }
00037
00038 bool TreeScorer::Score(SyntaxTree &root)
00039 {
00040 scores_.clear();
00041 ZeroScores(root);
00042 if (!CalcScores(root)) {
00043 return false;
00044 }
00045 SetAttributes(root);
00046 return true;
00047 }
00048
00049 bool TreeScorer::CalcScores(SyntaxTree &root)
00050 {
00051 if (root.IsLeaf() || root.children()[0]->IsLeaf()) {
00052 return true;
00053 }
00054
00055 const std::vector<SyntaxTree *> &children = root.children();
00056
00057 double log_prob = 0.0;
00058
00059 std::vector<std::size_t> key;
00060 key.reserve(children.size()+1);
00061 key.push_back(non_term_vocab_.Lookup(root.value().label));
00062
00063 for (std::vector<SyntaxTree *>::const_iterator p(children.begin());
00064 p != children.end(); ++p) {
00065 SyntaxTree *child = *p;
00066 assert(!child->IsLeaf());
00067 key.push_back(non_term_vocab_.Lookup(child->value().label));
00068 if (!CalcScores(*child)) {
00069 return false;
00070 }
00071 if (!child->children()[0]->IsLeaf()) {
00072 log_prob += scores_[child];
00073 }
00074 }
00075 double rule_score;
00076 bool found = pcfg_.Lookup(key, rule_score);
00077 if (!found) {
00078 return false;
00079 }
00080 log_prob += rule_score;
00081 scores_[&root] = log_prob;
00082 return true;
00083 }
00084
00085 void TreeScorer::SetAttributes(SyntaxTree &root)
00086 {
00087
00088 if (root.IsLeaf()) {
00089 return;
00090 }
00091
00092 if (root.children()[0]->IsLeaf()) {
00093 return;
00094 }
00095 double score = scores_[&root];
00096 if (score != 0.0) {
00097 std::ostringstream out;
00098 out << score;
00099 root.value().attributes["pcfg"] = out.str();
00100 }
00101 for (std::vector<SyntaxTree *>::const_iterator p(root.children().begin());
00102 p != root.children().end(); ++p) {
00103 SetAttributes(**p);
00104 }
00105 }
00106
00107 void TreeScorer::ZeroScores(SyntaxTree &root)
00108 {
00109 scores_[&root] = 0.0f;
00110 const std::vector<SyntaxTree *> &children = root.children();
00111 for (std::vector<SyntaxTree *>::const_iterator p(children.begin());
00112 p != children.end(); ++p) {
00113 ZeroScores(**p);
00114 }
00115 }
00116
00117 }
00118 }
00119 }