00001 #include "lm/builder/adjust_counts.hh"
00002
00003 #include "lm/common/ngram_stream.hh"
00004 #include "lm/builder/payload.hh"
00005 #include "util/scoped.hh"
00006
00007 #include <boost/thread/thread.hpp>
00008 #define BOOST_TEST_MODULE AdjustCounts
00009 #include <boost/test/unit_test.hpp>
00010
00011 namespace lm { namespace builder { namespace {
00012
00013 class KeepCopy {
00014 public:
00015 KeepCopy() : size_(0) {}
00016
00017 void Run(const util::stream::ChainPosition &position) {
00018 for (util::stream::Link link(position); link; ++link) {
00019 mem_.call_realloc(size_ + link->ValidSize());
00020 memcpy(static_cast<uint8_t*>(mem_.get()) + size_, link->Get(), link->ValidSize());
00021 size_ += link->ValidSize();
00022 }
00023 }
00024
00025 uint8_t *Get() { return static_cast<uint8_t*>(mem_.get()); }
00026 std::size_t Size() const { return size_; }
00027
00028 private:
00029 util::scoped_malloc mem_;
00030 std::size_t size_;
00031 };
00032
00033 struct Gram4 {
00034 WordIndex ids[4];
00035 uint64_t count;
00036 };
00037
00038 class WriteInput {
00039 public:
00040 void Run(const util::stream::ChainPosition &position) {
00041 NGramStream<BuildingPayload> input(position);
00042 Gram4 grams[] = {
00043 {{0,0,0,0},10},
00044 {{0,0,3,0},3},
00045
00046 {{1,1,1,2},5},
00047 {{0,0,3,2},5},
00048 };
00049 for (size_t i = 0; i < sizeof(grams) / sizeof(Gram4); ++i, ++input) {
00050 memcpy(input->begin(), grams[i].ids, sizeof(WordIndex) * 4);
00051 input->Value().count = grams[i].count;
00052 }
00053 input.Poison();
00054 }
00055 };
00056
00057 BOOST_AUTO_TEST_CASE(Simple) {
00058 KeepCopy outputs[4];
00059 std::vector<uint64_t> counts;
00060 std::vector<Discount> discount;
00061 {
00062 util::stream::ChainConfig config;
00063 config.total_memory = 100;
00064 config.block_count = 1;
00065 util::stream::Chains chains(4);
00066 for (unsigned i = 0; i < 4; ++i) {
00067 config.entry_size = NGram<BuildingPayload>::TotalSize(i + 1);
00068 chains.push_back(config);
00069 }
00070
00071 chains[3] >> WriteInput();
00072 util::stream::ChainPositions for_adjust(chains);
00073 for (unsigned i = 0; i < 4; ++i) {
00074 chains[i] >> boost::ref(outputs[i]);
00075 }
00076 chains >> util::stream::kRecycle;
00077 std::vector<uint64_t> counts_pruned(4);
00078 std::vector<uint64_t> prune_thresholds(4);
00079 DiscountConfig discount_config;
00080 discount_config.fallback = Discount();
00081 discount_config.bad_action = THROW_UP;
00082 BOOST_CHECK_THROW(AdjustCounts(prune_thresholds, counts, counts_pruned, std::vector<bool>(), discount_config, discount).Run(for_adjust), BadDiscountException);
00083 }
00084 BOOST_REQUIRE_EQUAL(4UL, counts.size());
00085 BOOST_CHECK_EQUAL(4UL, counts[0]);
00086
00087
00088
00089
00090 BOOST_REQUIRE_EQUAL(NGram<BuildingPayload>::TotalSize(1) * 4, outputs[0].Size());
00091 NGram<BuildingPayload> uni(outputs[0].Get(), 1);
00092 BOOST_CHECK_EQUAL(kUNK, *uni.begin());
00093 BOOST_CHECK_EQUAL(0ULL, uni.Value().count);
00094 uni.NextInMemory();
00095 BOOST_CHECK_EQUAL(kBOS, *uni.begin());
00096 BOOST_CHECK_EQUAL(0ULL, uni.Value().count);
00097 uni.NextInMemory();
00098 BOOST_CHECK_EQUAL(0UL, *uni.begin());
00099 BOOST_CHECK_EQUAL(2ULL, uni.Value().count);
00100 uni.NextInMemory();
00101 BOOST_CHECK_EQUAL(2ULL, uni.Value().count);
00102 BOOST_CHECK_EQUAL(2UL, *uni.begin());
00103
00104 BOOST_REQUIRE_EQUAL(NGram<BuildingPayload>::TotalSize(2) * 4, outputs[1].Size());
00105 NGram<BuildingPayload> bi(outputs[1].Get(), 2);
00106 BOOST_CHECK_EQUAL(0UL, *bi.begin());
00107 BOOST_CHECK_EQUAL(0UL, *(bi.begin() + 1));
00108 BOOST_CHECK_EQUAL(1ULL, bi.Value().count);
00109 bi.NextInMemory();
00110 }
00111
00112 }}}