00001 #include "lm/left.hh"
00002 #include "lm/model.hh"
00003
00004 #include "util/tokenize_piece.hh"
00005
00006 #include <vector>
00007
00008 #define BOOST_TEST_MODULE LeftTest
00009 #include <boost/test/unit_test.hpp>
00010 #include <boost/test/floating_point_comparison.hpp>
00011
00012 namespace lm {
00013 namespace ngram {
00014 namespace {
00015
// Test shorthands.  Term(word) feeds `word` as a terminal to the RuleScore
// named `score`; VCheck(word, value) checks that `value` equals the vocabulary
// index of `word`.  Both expect the model `m` to be in scope at the use site.
// Neither definition carries a trailing semicolon, so each use behaves as a
// single statement (e.g. as the unbraced body of an if/else).
#define Term(word) score.Terminal(m.GetVocabulary().Index(word))
#define VCheck(word, value) BOOST_CHECK_EQUAL(m.GetVocabulary().Index(word), value)
00018
00019
// BOOST_CHECK_CLOSE with all three arguments widened to double, so float and
// double operands can be mixed freely.  As with BOOST_CHECK_CLOSE, `tol` is a
// percentage.  No trailing semicolon: each use is a single statement.
#define SLOPPY_CHECK_CLOSE(ref, value, tol) BOOST_CHECK_CLOSE(static_cast<double>(ref), static_cast<double>(value), static_cast<double>(tol))
00021
00022 template <class M> void Short(const M &m) {
00023 ChartState base;
00024 {
00025 RuleScore<M> score(m, base);
00026 Term("more");
00027 Term("loin");
00028 SLOPPY_CHECK_CLOSE(-1.206319 - 0.3561665, score.Finish(), 0.001);
00029 }
00030 BOOST_CHECK(base.left.full);
00031 BOOST_CHECK_EQUAL(2, base.left.length);
00032 BOOST_CHECK_EQUAL(1, base.right.length);
00033 VCheck("loin", base.right.words[0]);
00034
00035 ChartState more_left;
00036 {
00037 RuleScore<M> score(m, more_left);
00038 Term("little");
00039 score.NonTerminal(base, -1.206319 - 0.3561665);
00040
00041 SLOPPY_CHECK_CLOSE(-1.56538, score.Finish(), 0.001);
00042 }
00043 BOOST_CHECK_EQUAL(3, more_left.left.length);
00044 BOOST_CHECK_EQUAL(1, more_left.right.length);
00045 VCheck("loin", more_left.right.words[0]);
00046 BOOST_CHECK(more_left.left.full);
00047
00048 ChartState shorter;
00049 {
00050 RuleScore<M> score(m, shorter);
00051 Term("to");
00052 score.NonTerminal(base, -1.206319 - 0.3561665);
00053 SLOPPY_CHECK_CLOSE(-0.30103 - 1.687872 - 1.206319 - 0.3561665, score.Finish(), 0.01);
00054 }
00055 BOOST_CHECK_EQUAL(1, shorter.left.length);
00056 BOOST_CHECK_EQUAL(1, shorter.right.length);
00057 VCheck("loin", shorter.right.words[0]);
00058 BOOST_CHECK(shorter.left.full);
00059 }
00060
00061 template <class M> void Charge(const M &m) {
00062 ChartState base;
00063 {
00064 RuleScore<M> score(m, base);
00065 Term("on");
00066 Term("more");
00067 SLOPPY_CHECK_CLOSE(-1.509559 -0.4771212 -1.206319, score.Finish(), 0.001);
00068 }
00069 BOOST_CHECK_EQUAL(1, base.left.length);
00070 BOOST_CHECK_EQUAL(1, base.right.length);
00071 VCheck("more", base.right.words[0]);
00072 BOOST_CHECK(base.left.full);
00073
00074 ChartState extend;
00075 {
00076 RuleScore<M> score(m, extend);
00077 Term("looking");
00078 score.NonTerminal(base, -1.509559 -0.4771212 -1.206319);
00079 SLOPPY_CHECK_CLOSE(-3.91039, score.Finish(), 0.001);
00080 }
00081 BOOST_CHECK_EQUAL(2, extend.left.length);
00082 BOOST_CHECK_EQUAL(1, extend.right.length);
00083 VCheck("more", extend.right.words[0]);
00084 BOOST_CHECK(extend.left.full);
00085
00086 ChartState tobos;
00087 {
00088 RuleScore<M> score(m, tobos);
00089 score.BeginSentence();
00090 score.NonTerminal(extend, -3.91039);
00091 SLOPPY_CHECK_CLOSE(-3.471169, score.Finish(), 0.001);
00092 }
00093 BOOST_CHECK_EQUAL(0, tobos.left.length);
00094 BOOST_CHECK_EQUAL(1, tobos.right.length);
00095 }
00096
00097 template <class M> float LeftToRight(const M &m, const std::vector<WordIndex> &words, bool begin_sentence = false) {
00098 float ret = 0.0;
00099 State right = begin_sentence ? m.BeginSentenceState() : m.NullContextState();
00100 for (std::vector<WordIndex>::const_iterator i = words.begin(); i != words.end(); ++i) {
00101 State copy(right);
00102 ret += m.Score(copy, *i, right);
00103 }
00104 return ret;
00105 }
00106
00107 template <class M> float RightToLeft(const M &m, const std::vector<WordIndex> &words, bool begin_sentence = false) {
00108 float ret = 0.0;
00109 ChartState state;
00110 state.left.length = 0;
00111 state.right.length = 0;
00112 state.left.full = false;
00113 for (std::vector<WordIndex>::const_reverse_iterator i = words.rbegin(); i != words.rend(); ++i) {
00114 ChartState copy(state);
00115 RuleScore<M> score(m, state);
00116 score.Terminal(*i);
00117 score.NonTerminal(copy, ret);
00118 ret = score.Finish();
00119 }
00120 if (begin_sentence) {
00121 ChartState copy(state);
00122 RuleScore<M> score(m, state);
00123 score.BeginSentence();
00124 score.NonTerminal(copy, ret);
00125 ret = score.Finish();
00126 }
00127 return ret;
00128 }
00129
00130 template <class M> float TreeMiddle(const M &m, const std::vector<WordIndex> &words, bool begin_sentence = false) {
00131 std::vector<std::pair<ChartState, float> > states(words.size());
00132 for (unsigned int i = 0; i < words.size(); ++i) {
00133 RuleScore<M> score(m, states[i].first);
00134 score.Terminal(words[i]);
00135 states[i].second = score.Finish();
00136 }
00137 while (states.size() > 1) {
00138 std::vector<std::pair<ChartState, float> > upper((states.size() + 1) / 2);
00139 for (unsigned int i = 0; i < states.size() / 2; ++i) {
00140 RuleScore<M> score(m, upper[i].first);
00141 score.NonTerminal(states[i*2].first, states[i*2].second);
00142 score.NonTerminal(states[i*2+1].first, states[i*2+1].second);
00143 upper[i].second = score.Finish();
00144 }
00145 if (states.size() % 2) {
00146 upper.back() = states.back();
00147 }
00148 std::swap(states, upper);
00149 }
00150
00151 if (states.empty()) return 0.0;
00152
00153 if (begin_sentence) {
00154 ChartState ignored;
00155 RuleScore<M> score(m, ignored);
00156 score.BeginSentence();
00157 score.NonTerminal(states.front().first, states.front().second);
00158 return score.Finish();
00159 } else {
00160 return states.front().second;
00161 }
00162
00163 }
00164
00165 template <class M> void LookupVocab(const M &m, const StringPiece &str, std::vector<WordIndex> &out) {
00166 out.clear();
00167 for (util::TokenIter<util::SingleCharacter, true> i(str, ' '); i; ++i) {
00168 out.push_back(m.GetVocabulary().Index(*i));
00169 }
00170 }
00171
// Looks up the space-separated sentence `str` and checks that RightToLeft and
// TreeMiddle both agree with LeftToRight.  The left-to-right score is left in
// `expect`.  Expects `m`, `words`, `expect` and `rest` (forwarded as
// begin_sentence) to be in scope at the use site.  The body is wrapped in
// do/while(false) so a use behaves as one statement, even as an unbraced
// if/for body.  The last line deliberately has no continuation backslash.
#define TEXT_TEST(str) \
  do { \
    LookupVocab(m, str, words); \
    expect = LeftToRight(m, words, rest); \
    SLOPPY_CHECK_CLOSE(expect, RightToLeft(m, words, rest), 0.001); \
    SLOPPY_CHECK_CLOSE(expect, TreeMiddle(m, words, rest), 0.001); \
  } while (false)
00178
00179 template <class M> void GrowBig(const M &m, bool rest = false) {
00180 std::vector<WordIndex> words;
00181 float expect;
00182 TEXT_TEST("in biarritz watching considering looking . on a little more loin also would consider higher to look good unknown the screening foo bar , unknown however unknown </s>");
00183 TEXT_TEST("on a little more loin also would consider higher to look good unknown the screening foo bar , unknown however unknown </s>");
00184 TEXT_TEST("on a little more loin also would consider higher to look good");
00185 TEXT_TEST("more loin also would consider higher to look good");
00186 TEXT_TEST("more loin also would consider higher to look");
00187 TEXT_TEST("also would consider higher to look");
00188 TEXT_TEST("also would consider higher");
00189 TEXT_TEST("would consider higher to look");
00190 TEXT_TEST("consider higher to look");
00191 TEXT_TEST("consider higher to");
00192 TEXT_TEST("consider higher");
00193 }
00194
00195 template <class M> void GrowSmall(const M &m, bool rest = false) {
00196 std::vector<WordIndex> words;
00197 float expect;
00198 TEXT_TEST("in biarritz watching considering looking . </s>");
00199 TEXT_TEST("in biarritz watching considering looking .");
00200 TEXT_TEST("in biarritz");
00201 }
00202
// Builds "also would consider higher" several ways: single words, "also would"
// both from two nonterminals and from two terminals, "consider higher" from two
// nonterminals, and finally the two halves combined.  Each step checks the
// score and state lengths RuleScore reports.  Expected values are hard-coded
// for the model loaded by Everything() (test.arpa unless a path is given).
template <class M> void AlsoWouldConsiderHigher(const M &m) {
  // Each of "also" and "would" scored alone.
  ChartState also;
  {
    RuleScore<M> score(m, also);
    score.Terminal(m.GetVocabulary().Index("also"));
    SLOPPY_CHECK_CLOSE(-1.687872, score.Finish(), 0.001);
  }
  ChartState would;
  {
    RuleScore<M> score(m, would);
    score.Terminal(m.GetVocabulary().Index("would"));
    SLOPPY_CHECK_CLOSE(-1.687872, score.Finish(), 0.001);
  }
  // "also would" from two nonterminals must score the same as the
  // terminal-by-terminal construction below.
  ChartState combine_also_would;
  {
    RuleScore<M> score(m, combine_also_would);
    score.NonTerminal(also, -1.687872);
    score.NonTerminal(would, -1.687872);
    SLOPPY_CHECK_CLOSE(-1.687872 - 2.0, score.Finish(), 0.001);
  }
  BOOST_CHECK_EQUAL(2, combine_also_would.right.length);

  ChartState also_would;
  {
    RuleScore<M> score(m, also_would);
    score.Terminal(m.GetVocabulary().Index("also"));
    score.Terminal(m.GetVocabulary().Index("would"));
    SLOPPY_CHECK_CLOSE(-1.687872 - 2.0, score.Finish(), 0.001);
  }
  BOOST_CHECK_EQUAL(2, also_would.right.length);

  // "consider" alone: its left state is expected not to be full.
  ChartState consider;
  {
    RuleScore<M> score(m, consider);
    score.Terminal(m.GetVocabulary().Index("consider"));
    SLOPPY_CHECK_CLOSE(-1.687872, score.Finish(), 0.001);
  }
  BOOST_CHECK_EQUAL(1, consider.left.length);
  BOOST_CHECK_EQUAL(1, consider.right.length);
  BOOST_CHECK(!consider.left.full);

  // "higher" alone: this score is captured rather than hard-coded, and its
  // right-state backoff is checked as well.
  ChartState higher;
  float higher_score;
  {
    RuleScore<M> score(m, higher);
    score.Terminal(m.GetVocabulary().Index("higher"));
    higher_score = score.Finish();
  }
  SLOPPY_CHECK_CLOSE(-1.509559, higher_score, 0.001);
  BOOST_CHECK_EQUAL(1, higher.left.length);
  BOOST_CHECK_EQUAL(1, higher.right.length);
  BOOST_CHECK(!higher.left.full);
  VCheck("higher", higher.right.words[0]);
  SLOPPY_CHECK_CLOSE(-0.30103, higher.right.backoff[0], 0.001);

  // "consider higher" from two nonterminals; the combined left state still
  // isn't expected to be full.
  ChartState consider_higher;
  {
    RuleScore<M> score(m, consider_higher);
    score.NonTerminal(consider, -1.687872);
    score.NonTerminal(higher, higher_score);
    SLOPPY_CHECK_CLOSE(-1.509559 - 1.687872 - 0.30103, score.Finish(), 0.001);
  }
  BOOST_CHECK_EQUAL(2, consider_higher.left.length);
  BOOST_CHECK(!consider_higher.left.full);

  // Whole phrase: [also would] followed by [consider higher].
  ChartState full;
  {
    RuleScore<M> score(m, full);
    score.NonTerminal(combine_also_would, -1.687872 - 2.0);
    score.NonTerminal(consider_higher, -1.509559 - 1.687872 - 0.30103);
    SLOPPY_CHECK_CLOSE(-10.6879, score.Finish(), 0.001);
  }
  BOOST_CHECK_EQUAL(4, full.right.length);
}
00277
// Checks that `val` (evaluated once, before the lookup) is close to the
// left-to-right score of the space-separated text `str`.  Expects the model
// `m` in scope at the use site.  The do/while(false) wrapper makes each use a
// single statement, so `if (c) CHECK_SCORE(...); else ...` compiles.  `val` is
// parenthesized so assignment expressions are passed through intact.
#define CHECK_SCORE(str, val) \
  do { \
    float got = (val); \
    std::vector<WordIndex> indices; \
    LookupVocab(m, str, indices); \
    SLOPPY_CHECK_CLOSE(LeftToRight(m, indices), got, 0.001); \
  } while (false)
00285
00286 template <class M> void FullGrow(const M &m) {
00287 std::vector<WordIndex> words;
00288 LookupVocab(m, "in biarritz watching considering looking . </s>", words);
00289
00290 ChartState lexical[7];
00291 float lexical_scores[7];
00292 for (unsigned int i = 0; i < 7; ++i) {
00293 RuleScore<M> score(m, lexical[i]);
00294 score.Terminal(words[i]);
00295 lexical_scores[i] = score.Finish();
00296 }
00297 CHECK_SCORE("in", lexical_scores[0]);
00298 CHECK_SCORE("biarritz", lexical_scores[1]);
00299 CHECK_SCORE("watching", lexical_scores[2]);
00300 CHECK_SCORE("</s>", lexical_scores[6]);
00301
00302 ChartState l1[4];
00303 float l1_scores[4];
00304 {
00305 RuleScore<M> score(m, l1[0]);
00306 score.NonTerminal(lexical[0], lexical_scores[0]);
00307 score.NonTerminal(lexical[1], lexical_scores[1]);
00308 CHECK_SCORE("in biarritz", l1_scores[0] = score.Finish());
00309 }
00310 {
00311 RuleScore<M> score(m, l1[1]);
00312 score.NonTerminal(lexical[2], lexical_scores[2]);
00313 score.NonTerminal(lexical[3], lexical_scores[3]);
00314 CHECK_SCORE("watching considering", l1_scores[1] = score.Finish());
00315 }
00316 {
00317 RuleScore<M> score(m, l1[2]);
00318 score.NonTerminal(lexical[4], lexical_scores[4]);
00319 score.NonTerminal(lexical[5], lexical_scores[5]);
00320 CHECK_SCORE("looking .", l1_scores[2] = score.Finish());
00321 }
00322 BOOST_CHECK_EQUAL(l1[2].left.length, 1);
00323 l1[3] = lexical[6];
00324 l1_scores[3] = lexical_scores[6];
00325
00326 ChartState l2[2];
00327 float l2_scores[2];
00328 {
00329 RuleScore<M> score(m, l2[0]);
00330 score.NonTerminal(l1[0], l1_scores[0]);
00331 score.NonTerminal(l1[1], l1_scores[1]);
00332 CHECK_SCORE("in biarritz watching considering", l2_scores[0] = score.Finish());
00333 }
00334 {
00335 RuleScore<M> score(m, l2[1]);
00336 score.NonTerminal(l1[2], l1_scores[2]);
00337 score.NonTerminal(l1[3], l1_scores[3]);
00338 CHECK_SCORE("looking . </s>", l2_scores[1] = score.Finish());
00339 }
00340 BOOST_CHECK_EQUAL(l2[1].left.length, 1);
00341 BOOST_CHECK(l2[1].left.full);
00342
00343 ChartState top;
00344 {
00345 RuleScore<M> score(m, top);
00346 score.NonTerminal(l2[0], l2_scores[0]);
00347 score.NonTerminal(l2[1], l2_scores[1]);
00348 CHECK_SCORE("in biarritz watching considering looking . </s>", score.Finish());
00349 }
00350 }
00351
00352 const char *FileLocation() {
00353 if (boost::unit_test::framework::master_test_suite().argc < 2) {
00354 return "test.arpa";
00355 }
00356 return boost::unit_test::framework::master_test_suite().argv[1];
00357 }
00358
00359 template <class M> void Everything() {
00360 Config config;
00361 config.messages = NULL;
00362 M m(FileLocation(), config);
00363
00364 Short(m);
00365 Charge(m);
00366 GrowBig(m);
00367 AlsoWouldConsiderHigher(m);
00368 GrowSmall(m);
00369 FullGrow(m);
00370 }
00371
// One test case per model type.  Each runs the full battery in Everything()
// against the same file, so every type is held to the same hard-coded
// expectations.
BOOST_AUTO_TEST_CASE(ProbingAll) {
  Everything<Model>();
}
BOOST_AUTO_TEST_CASE(TrieAll) {
  Everything<TrieModel>();
}
BOOST_AUTO_TEST_CASE(QuantTrieAll) {
  Everything<QuantTrieModel>();
}
BOOST_AUTO_TEST_CASE(ArrayQuantTrieAll) {
  Everything<QuantArrayTrieModel>();
}
BOOST_AUTO_TEST_CASE(ArrayTrieAll) {
  Everything<ArrayTrieModel>();
}
00387
00388 BOOST_AUTO_TEST_CASE(RestProbing) {
00389 Config config;
00390 config.messages = NULL;
00391 RestProbingModel m(FileLocation(), config);
00392 GrowBig(m, true);
00393 }
00394
00395 }
00396 }
00397 }