00001 #ifndef LM_FILTER_THREAD_H
00002 #define LM_FILTER_THREAD_H
00003
00004 #include "util/thread_pool.hh"
00005
00006 #include <boost/utility/in_place_factory.hpp>
00007
00008 #include <deque>
00009 #include <stack>
00010
00011 namespace lm {
00012
00013 template <class OutputBuffer> class ThreadBatch {
00014 public:
00015 ThreadBatch() {}
00016
00017 void Reserve(size_t size) {
00018 input_.Reserve(size);
00019 output_.Reserve(size);
00020 }
00021
00022
00023 InputBuffer &Fill(uint64_t sequence) {
00024 sequence_ = sequence;
00025
00026
00027 input_.Clear();
00028 return input_;
00029 }
00030
00031
00032 template <class Filter> void CallFilter(Filter &filter) {
00033 input_.CallFilter(filter, output_);
00034 }
00035
00036 uint64_t Sequence() const { return sequence_; }
00037
00038
00039 template <class RealOutput> void Flush(RealOutput &output) {
00040 output_.Flush(output);
00041 }
00042
00043 private:
00044 InputBuffer input_;
00045 OutputBuffer output_;
00046
00047 uint64_t sequence_;
00048 };
00049
00050 template <class Batch, class Filter> class FilterWorker {
00051 public:
00052 typedef Batch *Request;
00053
00054 FilterWorker(const Filter &filter, util::PCQueue<Request> &done) : filter_(filter), done_(done) {}
00055
00056 void operator()(Request request) {
00057 request->CallFilter(filter_);
00058 done_.Produce(request);
00059 }
00060
00061 private:
00062 Filter filter_;
00063
00064 util::PCQueue<Request> &done_;
00065 };
00066
00067
00068 template <class Batch, class Output> class OutputWorker {
00069 public:
00070 typedef Batch *Request;
00071
00072 OutputWorker(Output &output, util::PCQueue<Request> &done) : output_(output), done_(done), base_sequence_(0) {}
00073
00074 void operator()(Request request) {
00075 assert(request->Sequence() >= base_sequence_);
00076
00077 uint64_t pos = request->Sequence() - base_sequence_;
00078 if (pos >= ordering_.size()) {
00079 ordering_.resize(pos + 1, NULL);
00080 }
00081 ordering_[pos] = request;
00082 while (!ordering_.empty() && ordering_.front()) {
00083 ordering_.front()->Flush(output_);
00084 done_.Produce(ordering_.front());
00085 ordering_.pop_front();
00086 ++base_sequence_;
00087 }
00088 }
00089
00090 private:
00091 Output &output_;
00092
00093 util::PCQueue<Request> &done_;
00094
00095 std::deque<Request> ordering_;
00096
00097 uint64_t base_sequence_;
00098 };
00099
00100 template <class Filter, class OutputBuffer, class RealOutput> class Controller : boost::noncopyable {
00101 private:
00102 typedef ThreadBatch<OutputBuffer> Batch;
00103
00104 public:
00105 Controller(size_t batch_size, size_t queue, size_t workers, const Filter &filter, RealOutput &output)
00106 : batch_size_(batch_size), queue_size_(queue),
00107 batches_(queue),
00108 to_read_(queue),
00109 output_(queue, 1, boost::in_place(boost::ref(output), boost::ref(to_read_)), NULL),
00110 filter_(queue, workers, boost::in_place(boost::ref(filter), boost::ref(output_.In())), NULL),
00111 sequence_(0) {
00112 for (size_t i = 0; i < queue; ++i) {
00113 batches_[i].Reserve(batch_size);
00114 local_read_.push(&batches_[i]);
00115 }
00116 NewInput();
00117 }
00118
00119 void AddNGram(const StringPiece &ngram, const StringPiece &line, RealOutput &output) {
00120 input_->AddNGram(ngram, line, output);
00121 if (input_->Size() == batch_size_) {
00122 FlushInput();
00123 NewInput();
00124 }
00125 }
00126
00127 void Flush() {
00128 FlushInput();
00129 while (local_read_.size() < queue_size_) {
00130 MoveRead();
00131 }
00132 NewInput();
00133 }
00134
00135 private:
00136 void FlushInput() {
00137 if (input_->Empty()) return;
00138 filter_.Produce(local_read_.top());
00139 local_read_.pop();
00140 if (local_read_.empty()) MoveRead();
00141 }
00142
00143 void NewInput() {
00144 input_ = &local_read_.top()->Fill(sequence_++);
00145 }
00146
00147 void MoveRead() {
00148 local_read_.push(to_read_.Consume());
00149 }
00150
00151 const size_t batch_size_;
00152 const size_t queue_size_;
00153
00154 std::vector<Batch> batches_;
00155
00156 util::PCQueue<Batch*> to_read_;
00157 std::stack<Batch*> local_read_;
00158 util::ThreadPool<OutputWorker<Batch, RealOutput> > output_;
00159 util::ThreadPool<FilterWorker<Batch, Filter> > filter_;
00160
00161 uint64_t sequence_;
00162 InputBuffer *input_;
00163 };
00164
00165 }
00166
00167 #endif // LM_FILTER_THREAD_H