00001 #include "search/vertex.hh"
00002
00003 #include "search/context.hh"
00004
00005 #include <boost/unordered_map.hpp>
00006
00007 #include <algorithm>
00008 #include <functional>
00009 #include <cassert>
00010
00011 namespace search {
00012
00013 namespace {
00014
// All-ones 64-bit value. DivideLeft/DivideRight subtract the left state's
// `full` flag from it to form sentinel keys for states that end before the
// split position.
const uint64_t kCompleteAdd = ~static_cast<uint64_t>(0);
00016
00017 class DivideLeft {
00018 public:
00019 explicit DivideLeft(unsigned char index)
00020 : index_(index) {}
00021
00022 uint64_t operator()(const lm::ngram::ChartState &state) const {
00023 return (index_ < state.left.length) ?
00024 state.left.pointers[index_] :
00025 (kCompleteAdd - state.left.full);
00026 }
00027
00028 private:
00029 unsigned char index_;
00030 };
00031
00032 class DivideRight {
00033 public:
00034 explicit DivideRight(unsigned char index)
00035 : index_(index) {}
00036
00037 uint64_t operator()(const lm::ngram::ChartState &state) const {
00038 return (index_ < state.right.length) ?
00039 static_cast<uint64_t>(state.right.words[index_]) :
00040 (kCompleteAdd - state.left.full);
00041 }
00042
00043 private:
00044 unsigned char index_;
00045 };
00046
00047 template <class Divider> void Split(const Divider ÷r, const std::vector<HypoState> &hypos, std::vector<VertexNode> &extend) {
00048
00049 typedef boost::unordered_map<uint64_t, std::size_t> Lookup;
00050 Lookup lookup;
00051 for (std::vector<HypoState>::const_iterator i = hypos.begin(); i != hypos.end(); ++i) {
00052 uint64_t key = divider(i->state);
00053 std::pair<Lookup::iterator, bool> res(lookup.insert(std::make_pair(key, extend.size())));
00054 if (res.second) {
00055 extend.resize(extend.size() + 1);
00056 extend.back().AppendHypothesis(*i);
00057 } else {
00058 extend[res.first->second].AppendHypothesis(*i);
00059 }
00060 }
00061
00062 }
00063
00064 lm::WordIndex Identify(const lm::ngram::Right &right, unsigned char index) {
00065 return right.words[index];
00066 }
00067
00068 uint64_t Identify(const lm::ngram::Left &left, unsigned char index) {
00069 return left.pointers[index];
00070 }
00071
// Measures how much of one side (Left or Right) of a group of states agrees
// with a reference state.
//
// Shared():   length of the prefix on which every state considered so far
//             agrees with the reference (never longer than any of their
//             lengths).
// Complete(): true only while every considered state has the reference's
//             length and matches it at every compared position.
//
// Positions below `guaranteed` are assumed to match and are never compared.
// Only a reference to `side` is kept, so `side` must outlive this object.
template <class Side> class DetermineSame {
  public:
    DetermineSame(const Side &side, unsigned char guaranteed)
      : side_(side), guaranteed_(guaranteed), shared_(side.length), complete_(true) {}

    void Consider(const Side &other) {
      // Any length disagreement makes the group incomplete, and the shared
      // prefix can be no longer than the shortest state seen.
      if (other.length != shared_) {
        complete_ = false;
        if (other.length < shared_) shared_ = other.length;
      }
      // Walk forward while both states agree.
      unsigned char pos = guaranteed_;
      while (pos < shared_ && Identify(side_, pos) == Identify(other, pos)) {
        ++pos;
      }
      // Stopped early on a mismatch: the shared prefix ends there.
      if (pos < shared_) {
        shared_ = pos;
        complete_ = false;
      }
    }

    unsigned char Shared() const { return shared_; }

    bool Complete() const { return complete_; }

  private:
    const Side &side_;
    unsigned char guaranteed_, shared_;
    bool complete_;
};
00101
00102
00103
// Values stored in VertexNode::policy_ to choose which side BuildExtend splits:
//   kPolicyAlternate -- split on the left when its shared length is <= the
//                       right's, otherwise on the right;
//   kPolicyOneLeft   -- always split on the left state;
//   kPolicyOneRight  -- always split on the right state.
const unsigned char kPolicyAlternate = 0,
                    kPolicyOneLeft = 1,
                    kPolicyOneRight = 2;
00109
00110
00111
00112 }
00113
00114 namespace {
00115 struct GreaterByScore : public std::binary_function<const HypoState &, const HypoState &, bool> {
00116 bool operator()(const HypoState &first, const HypoState &second) const {
00117 return first.score > second.score;
00118 }
00119 };
00120 }
00121
00122 void VertexNode::FinishRoot() {
00123 std::sort(hypos_.begin(), hypos_.end(), GreaterByScore());
00124 extend_.clear();
00125
00126 state_.left.full = false;
00127 state_.left.length = 0;
00128 state_.right.length = 0;
00129 right_full_ = false;
00130 niceness_ = 0;
00131 policy_ = kPolicyAlternate;
00132 if (hypos_.size() == 1) {
00133 extend_.resize(1);
00134 extend_.front().AppendHypothesis(hypos_.front());
00135 extend_.front().FinishedAppending(0, 0);
00136 }
00137 if (hypos_.empty()) {
00138 bound_ = -INFINITY;
00139 } else {
00140 bound_ = hypos_.front().score;
00141 }
00142 }
00143
00144 void VertexNode::FinishedAppending(const unsigned char common_left, const unsigned char common_right) {
00145 assert(!hypos_.empty());
00146 assert(extend_.empty());
00147 bound_ = hypos_.front().score;
00148 state_ = hypos_.front().state;
00149 bool all_full = state_.left.full;
00150 bool all_non_full = !state_.left.full;
00151 DetermineSame<lm::ngram::Left> left(state_.left, common_left);
00152 DetermineSame<lm::ngram::Right> right(state_.right, common_right);
00153 for (std::vector<HypoState>::const_iterator i = hypos_.begin() + 1; i != hypos_.end(); ++i) {
00154 all_full &= i->state.left.full;
00155 all_non_full &= !i->state.left.full;
00156 left.Consider(i->state.left);
00157 right.Consider(i->state.right);
00158 }
00159 state_.left.full = all_full && left.Complete();
00160 right_full_ = all_full && right.Complete();
00161 state_.left.length = left.Shared();
00162 state_.right.length = right.Shared();
00163
00164 if (!all_full && !all_non_full) {
00165 policy_ = kPolicyAlternate;
00166 } else if (left.Complete()) {
00167 policy_ = kPolicyOneRight;
00168 } else if (right.Complete()) {
00169 policy_ = kPolicyOneLeft;
00170 } else {
00171 policy_ = kPolicyAlternate;
00172 }
00173 niceness_ = state_.left.length + state_.right.length;
00174 }
00175
00176 void VertexNode::BuildExtend() {
00177
00178 if (!extend_.empty()) return;
00179
00180 if (hypos_.size() <= 1) return;
00181 bool left_branch = true;
00182 switch (policy_) {
00183 case kPolicyAlternate:
00184 left_branch = (state_.left.length <= state_.right.length);
00185 break;
00186 case kPolicyOneLeft:
00187 left_branch = true;
00188 break;
00189 case kPolicyOneRight:
00190 left_branch = false;
00191 break;
00192 }
00193 if (left_branch) {
00194 Split(DivideLeft(state_.left.length), hypos_, extend_);
00195 } else {
00196 Split(DivideRight(state_.right.length), hypos_, extend_);
00197 }
00198 for (std::vector<VertexNode>::iterator i = extend_.begin(); i != extend_.end(); ++i) {
00199
00200 i->FinishedAppending(state_.left.length, state_.right.length);
00201 }
00202 }
00203
00204 }