00001 #include "search/nbest.hh"
00002
00003 #include "util/pool.hh"
00004 #include "moses/Util.h"
00005
00006 #include <algorithm>
00007 #include <functional>
00008 #include <queue>
00009 #include <cassert>
00010 #include <cmath>
00011
00012 namespace search {
00013
00014 NBestList::NBestList(std::vector<PartialEdge> &partials, util::Pool &entry_pool, std::size_t keep) {
00015 assert(!partials.empty());
00016 std::vector<PartialEdge>::iterator end;
00017 if (partials.size() > keep) {
00018 end = partials.begin() + keep;
00019 NTH_ELEMENT4(partials.begin(), end, partials.end(), std::greater<PartialEdge>());
00020 } else {
00021 end = partials.end();
00022 }
00023 for (std::vector<PartialEdge>::const_iterator i(partials.begin()); i != end; ++i) {
00024 queue_.push(QueueEntry(entry_pool.Allocate(QueueEntry::Size(i->GetArity())), *i));
00025 }
00026 }
00027
00028 Score NBestList::TopAfterConstructor() const {
00029 assert(revealed_.empty());
00030 return queue_.top().GetScore();
00031 }
00032
00033 const std::vector<Applied> &NBestList::Extract(util::Pool &pool, std::size_t n) {
00034 while (revealed_.size() < n && !queue_.empty()) {
00035 MoveTop(pool);
00036 }
00037 return revealed_;
00038 }
00039
00040 Score NBestList::Visit(util::Pool &pool, std::size_t index) {
00041 if (index + 1 < revealed_.size())
00042 return revealed_[index + 1].GetScore() - revealed_[index].GetScore();
00043 if (queue_.empty())
00044 return -INFINITY;
00045 if (index + 1 == revealed_.size())
00046 return queue_.top().GetScore() - revealed_[index].GetScore();
00047 assert(index == revealed_.size());
00048
00049 MoveTop(pool);
00050
00051 if (queue_.empty()) return -INFINITY;
00052 return queue_.top().GetScore() - revealed_[index].GetScore();
00053 }
00054
00055 Applied NBestList::Get(util::Pool &pool, std::size_t index) {
00056 assert(index <= revealed_.size());
00057 if (index == revealed_.size()) MoveTop(pool);
00058 return revealed_[index];
00059 }
00060
00061 void NBestList::MoveTop(util::Pool &pool) {
00062 assert(!queue_.empty());
00063 QueueEntry entry(queue_.top());
00064 queue_.pop();
00065 RevealedRef *const children_begin = entry.Children();
00066 RevealedRef *const children_end = children_begin + entry.GetArity();
00067 Score basis = entry.GetScore();
00068 for (RevealedRef *child = children_begin; child != children_end; ++child) {
00069 Score change = child->in_->Visit(pool, child->index_);
00070 if (change != -INFINITY) {
00071 assert(change < 0.001);
00072 QueueEntry new_entry(pool.Allocate(QueueEntry::Size(entry.GetArity())), basis + change, entry.GetArity(), entry.GetNote(), entry.GetRange());
00073 std::copy(children_begin, child, new_entry.Children());
00074 RevealedRef *update = new_entry.Children() + (child - children_begin);
00075 update->in_ = child->in_;
00076 update->index_ = child->index_ + 1;
00077 std::copy(child + 1, children_end, update + 1);
00078 queue_.push(new_entry);
00079 }
00080
00081 if (child->index_) break;
00082 }
00083
00084
00085 void *overwrite = entry.Children();
00086 for (unsigned int i = 0; i < entry.GetArity(); ++i) {
00087 RevealedRef from(*(static_cast<const RevealedRef*>(overwrite) + i));
00088 *(static_cast<Applied*>(overwrite) + i) = from.in_->Get(pool, from.index_);
00089 }
00090 revealed_.push_back(Applied(entry.Base()));
00091 }
00092
00093 NBestComplete NBest::Complete(std::vector<PartialEdge> &partials) {
00094 assert(!partials.empty());
00095 NBestList *list = list_pool_.construct(partials, entry_pool_, config_.keep);
00096 return NBestComplete(
00097 list,
00098 partials.front().CompletedState(),
00099 list->TopAfterConstructor());
00100 }
00101
00102 const std::vector<Applied> &NBest::Extract(History history) {
00103 return static_cast<NBestList*>(history)->Extract(entry_pool_, config_.size);
00104 }
00105
00106 }