00001 #include "Manager.h"
00002 #include "Util.h"
00003 #include "SearchCubePruning.h"
00004 #include "StaticData.h"
00005 #include "InputType.h"
00006 #include "TranslationOptionCollection.h"
00007 #include <boost/foreach.hpp>
00008 using namespace std;
00009 
00010 namespace Moses
00011 {
00012 class BitmapContainerOrderer
00013 {
00014 public:
00015   bool operator()(const BitmapContainer* A, const BitmapContainer* B) const {
00016     if (B->Empty()) {
00017       if (A->Empty()) {
00018         return A < B;
00019       }
00020       return false;
00021     }
00022     if (A->Empty()) {
00023       return true;
00024     }
00025 
00026     
00027     const float scoreA = A->Top()->GetHypothesis()->GetFutureScore();
00028     const float scoreB = B->Top()->GetHypothesis()->GetFutureScore();
00029 
00030     if (scoreA < scoreB) {
00031       return true;
00032     } else if (scoreA > scoreB) {
00033       return false;
00034     } else {
00035       
00036       
00037       
00038       
00039       
00040       
00041       boost::shared_ptr<TargetPhrase> phrA = A->Top()->GetTargetPhrase();
00042       boost::shared_ptr<TargetPhrase> phrB = B->Top()->GetTargetPhrase();
00043       if (!phrA || !phrB) {
00044         
00045         return A < B;
00046       }
00047       return (phrA->Compare(*phrB) > 0);
00048     }
00049   }
00050 };
00051 
00052 SearchCubePruning::
00053 SearchCubePruning(Manager& manager, TranslationOptionCollection const& transOptColl)
00054   : Search(manager)
00055   , m_hypoStackColl(manager.GetSource().GetSize() + 1)
00056   , m_transOptColl(transOptColl)
00057 {
00058   std::vector < HypothesisStackCubePruning >::iterator iterStack;
00059   for (size_t ind = 0 ; ind < m_hypoStackColl.size() ; ++ind) {
00060     HypothesisStackCubePruning *sourceHypoColl = new HypothesisStackCubePruning(m_manager);
00061     sourceHypoColl->SetMaxHypoStackSize(m_options.search.stack_size);
00062     sourceHypoColl->SetBeamWidth(m_options.search.beam_width);
00063 
00064     m_hypoStackColl[ind] = sourceHypoColl;
00065   }
00066 }
00067 
00068 SearchCubePruning::~SearchCubePruning()
00069 {
00070   RemoveAllInColl(m_hypoStackColl);
00071 }
00072 
00077 void SearchCubePruning::Decode()
00078 {
00079   
00080   const Bitmap &initBitmap = m_bitmaps.GetInitialBitmap();
00081   Hypothesis *hypo = new Hypothesis(m_manager, m_source, m_initialTransOpt, initBitmap, m_manager.GetNextHypoId());
00082 
00083   HypothesisStackCubePruning &firstStack
00084   = *static_cast<HypothesisStackCubePruning*>(m_hypoStackColl.front());
00085   firstStack.AddInitial(hypo);
00086   
00087   firstStack.CleanupArcList();
00088   CreateForwardTodos(firstStack);
00089 
00090   const size_t PopLimit = m_manager.options()->cube.pop_limit;
00091   VERBOSE(2,"Cube Pruning pop limit is " << PopLimit << std::endl);
00092 
00093   const size_t Diversity = m_manager.options()->cube.diversity;
00094   VERBOSE(2,"Cube Pruning diversity is " << Diversity << std::endl);
00095   VERBOSE(2,"Max Phrase length is "
00096           << m_manager.options()->search.max_phrase_length << std::endl);
00097 
00098   
00099   size_t stackNo = 1;
00100   std::vector < HypothesisStack* >::iterator iterStack;
00101   for (iterStack = m_hypoStackColl.begin() + 1 ; iterStack != m_hypoStackColl.end() ; ++iterStack) {
00102     
00103     if (this->out_of_time()) return;
00104 
00105     HypothesisStackCubePruning &sourceHypoColl
00106     = *static_cast<HypothesisStackCubePruning*>(*iterStack);
00107 
00108     
00109     
00110     std::priority_queue < BitmapContainer*, std::vector< BitmapContainer* >,
00111         BitmapContainerOrderer > BCQueue;
00112 
00113     _BMType::const_iterator bmIter;
00114     const _BMType &accessor = sourceHypoColl.GetBitmapAccessor();
00115 
00116     for(bmIter = accessor.begin(); bmIter != accessor.end(); ++bmIter) {
00117       
00118       IFVERBOSE(2) {
00119         m_manager.GetSentenceStats().StartTimeOtherScore();
00120       }
00121       bmIter->second->InitializeEdges();
00122       IFVERBOSE(2) {
00123         m_manager.GetSentenceStats().StopTimeOtherScore();
00124       }
00125       m_manager.GetSentenceStats().StartTimeManageCubes();
00126       BCQueue.push(bmIter->second);
00127       m_manager.GetSentenceStats().StopTimeManageCubes();
00128 
00129     }
00130 
00131     
00132     for (size_t numpops = 1; numpops <= PopLimit && !BCQueue.empty(); numpops++) {
00133       
00134       m_manager.GetSentenceStats().StartTimeManageCubes();
00135       BitmapContainer *bc = BCQueue.top();
00136       BCQueue.pop();
00137       m_manager.GetSentenceStats().StopTimeManageCubes();
00138       IFVERBOSE(2) {
00139         m_manager.GetSentenceStats().AddPopped();
00140       }
00141       
00142       IFVERBOSE(2) {
00143         m_manager.GetSentenceStats().StartTimeOtherScore();
00144       }
00145       bc->ProcessBestHypothesis();
00146       IFVERBOSE(2) {
00147         m_manager.GetSentenceStats().StopTimeOtherScore();
00148       }
00149       
00150       m_manager.GetSentenceStats().StartTimeManageCubes();
00151       if (!bc->Empty())
00152         BCQueue.push(bc);
00153       m_manager.GetSentenceStats().StopTimeManageCubes();
00154     }
00155 
00156     
00157     
00158     if (Diversity > 0) {
00159       IFVERBOSE(2) {
00160         m_manager.GetSentenceStats().StartTimeOtherScore();
00161       }
00162       for(bmIter = accessor.begin(); bmIter != accessor.end(); ++bmIter) {
00163         bmIter->second->EnsureMinStackHyps(Diversity);
00164       }
00165       IFVERBOSE(2) {
00166         m_manager.GetSentenceStats().StopTimeOtherScore();
00167       }
00168     }
00169 
00170     
00171     VERBOSE(3,"processing hypothesis from next stack");
00172     IFVERBOSE(2) {
00173       m_manager.GetSentenceStats().StartTimeStack();
00174     }
00175     sourceHypoColl.PruneToSize(m_options.search.stack_size);
00176     VERBOSE(3,std::endl);
00177     sourceHypoColl.CleanupArcList();
00178     IFVERBOSE(2) {
00179       m_manager.GetSentenceStats().StopTimeStack();
00180     }
00181 
00182     IFVERBOSE(2) {
00183       m_manager.GetSentenceStats().StartTimeSetupCubes();
00184     }
00185     CreateForwardTodos(sourceHypoColl);
00186     IFVERBOSE(2) {
00187       m_manager.GetSentenceStats().StopTimeSetupCubes();
00188     }
00189 
00190     stackNo++;
00191   }
00192 }
00193 
00194 void SearchCubePruning::CreateForwardTodos(HypothesisStackCubePruning &stack)
00195 {
00196   const _BMType &bitmapAccessor = stack.GetBitmapAccessor();
00197   _BMType::const_iterator iterAccessor;
00198   size_t size = m_source.GetSize();
00199 
00200   stack.AddHypothesesToBitmapContainers();
00201 
00202   for (iterAccessor = bitmapAccessor.begin() ; iterAccessor != bitmapAccessor.end() ; ++iterAccessor) {
00203     const Bitmap &bitmap = *iterAccessor->first;
00204     BitmapContainer &bitmapContainer = *iterAccessor->second;
00205 
00206     if (bitmapContainer.GetHypothesesSize() == 0) {
00207       
00208       continue;
00209     }
00210 
00211     
00212     bitmapContainer.SortHypotheses();
00213 
00214     
00215     size_t startPos, endPos;
00216     for (startPos = 0 ; startPos < size ; startPos++) {
00217       if (bitmap.GetValue(startPos))
00218         continue;
00219 
00220       
00221       Range applyRange(startPos, startPos);
00222       if (CheckDistortion(bitmap, applyRange)) {
00223         
00224         CreateForwardTodos(bitmap, applyRange, bitmapContainer);
00225       }
00226 
00227       size_t maxSize = size - startPos;
00228       size_t maxSizePhrase = m_manager.options()->search.max_phrase_length;
00229       maxSize = std::min(maxSize, maxSizePhrase);
00230       for (endPos = startPos+1; endPos < startPos + maxSize; endPos++) {
00231         if (bitmap.GetValue(endPos))
00232           break;
00233 
00234         Range applyRange(startPos, endPos);
00235         if (CheckDistortion(bitmap, applyRange)) {
00236           
00237           CreateForwardTodos(bitmap, applyRange, bitmapContainer);
00238         }
00239       }
00240     }
00241   }
00242 }
00243 
00244 void
00245 SearchCubePruning::
00246 CreateForwardTodos(Bitmap const& bitmap, Range const& range,
00247                    BitmapContainer& bitmapContainer)
00248 {
00249   const Bitmap &newBitmap = m_bitmaps.GetBitmap(bitmap, range);
00250 
00251   size_t numCovered = newBitmap.GetNumWordsCovered();
00252   const TranslationOptionList* transOptList;
00253   transOptList = m_transOptColl.GetTranslationOptionList(range);
00254   const SquareMatrix &estimatedScores = m_transOptColl.GetEstimatedScores();
00255 
00256   if (transOptList && transOptList->size() > 0) {
00257     HypothesisStackCubePruning& newStack
00258     = *static_cast<HypothesisStackCubePruning*>(m_hypoStackColl[numCovered]);
00259     newStack.SetBitmapAccessor(newBitmap, newStack, range, bitmapContainer,
00260                                estimatedScores, *transOptList);
00261   }
00262 }
00263 
00264 bool
00265 SearchCubePruning::
00266 CheckDistortion(const Bitmap &hypoBitmap, const Range &range) const
00267 {
00268   
00269   int maxDistortion = m_manager.options()->reordering.max_distortion;
00270   if (maxDistortion < 0) return true;
00271 
00272   
00273   
00274   size_t const startPos = range.GetStartPos();
00275   size_t const endPos = range.GetEndPos();
00276 
00277   
00278   
00279   if (!m_source.GetReorderingConstraint().Check(hypoBitmap, startPos, endPos))
00280     return false;
00281 
00282   size_t const hypoFirstGapPos = hypoBitmap.GetFirstGapPos();
00283   
00284   if (hypoFirstGapPos == startPos) return true;
00285 
00286   
00287   
00288   
00289   
00290   
00291   
00292   
00293   
00294   Range bestNextExtension(hypoFirstGapPos, hypoFirstGapPos);
00295   return (m_source.ComputeDistortionDistance(range, bestNextExtension)
00296           <= maxDistortion);
00297 }
00298 
00303 Hypothesis const*
00304 SearchCubePruning::
00305 GetBestHypothesis() const
00306 {
00307   
00308   const HypothesisStack &hypoColl = *m_hypoStackColl.back();
00309   return hypoColl.GetBestHypothesis();
00310 }
00311 
00315 void
00316 SearchCubePruning::
00317 OutputHypoStackSize()
00318 {
00319   std::vector < HypothesisStack* >::const_iterator iterStack = m_hypoStackColl.begin();
00320   TRACE_ERR( "Stack sizes: " << (int)(*iterStack)->size());
00321   for (++iterStack; iterStack != m_hypoStackColl.end() ; ++iterStack) {
00322     TRACE_ERR( ", " << (int)(*iterStack)->size());
00323   }
00324   TRACE_ERR( endl);
00325 }
00326 
00327 void SearchCubePruning::PrintBitmapContainerGraph()
00328 {
00329   HypothesisStackCubePruning &lastStack = *static_cast<HypothesisStackCubePruning*>(m_hypoStackColl.back());
00330   const _BMType &bitmapAccessor = lastStack.GetBitmapAccessor();
00331 
00332   _BMType::const_iterator iterAccessor;
00333   for (iterAccessor = bitmapAccessor.begin(); iterAccessor != bitmapAccessor.end(); ++iterAccessor) {
00334     cerr << iterAccessor->first << endl;
00335     
00336   }
00337 
00338 }
00339 
00344 void SearchCubePruning::OutputHypoStack(int stack)
00345 {
00346   if (stack >= 0) {
00347     TRACE_ERR( "Stack " << stack << ": " << endl << m_hypoStackColl[stack] << endl);
00348   } else {
00349     
00350     int i = 0;
00351     vector < HypothesisStack* >::iterator iterStack;
00352     for (iterStack = m_hypoStackColl.begin() ; iterStack != m_hypoStackColl.end() ; ++iterStack) {
00353       HypothesisStackCubePruning &hypoColl = *static_cast<HypothesisStackCubePruning*>(*iterStack);
00354       TRACE_ERR( "Stack " << i++ << ": " << endl << hypoColl << endl);
00355     }
00356   }
00357 }
00358 
00359 const std::vector < HypothesisStack* >& SearchCubePruning::GetHypothesisStacks() const
00360 {
00361   return m_hypoStackColl;
00362 }
00363 
00364 }
00365