00001 #include "ScoreStsg.h"
00002
00003 #include <cassert>
00004 #include <cstdlib>
00005 #include <fstream>
00006 #include <iostream>
00007 #include <iterator>
00008 #include <string>
00009 #include <sstream>
00010 #include <vector>
00011
00012 #include <boost/program_options.hpp>
00013
00014 #include "util/string_piece.hh"
00015 #include "util/string_piece_hash.hh"
00016 #include "util/tokenize_piece.hh"
00017
00018 #include "InputFileStream.h"
00019 #include "OutputFileStream.h"
00020
00021 #include "syntax-common/exception.h"
00022
00023 #include "LexicalTable.h"
00024 #include "Options.h"
00025 #include "RuleGroup.h"
00026 #include "RuleTableWriter.h"
00027
00028 namespace MosesTraining
00029 {
00030 namespace Syntax
00031 {
00032 namespace ScoreStsg
00033 {
00034
00035 const int ScoreStsg::kCountOfCountsMax = 10;
00036
00037 ScoreStsg::ScoreStsg()
00038 : Tool("score-stsg")
00039 , m_lexTable(m_srcVocab, m_tgtVocab)
00040 , m_countOfCounts(kCountOfCountsMax, 0)
00041 , m_totalDistinct(0)
00042 {
00043 }
00044
// Entry point for the score-stsg tool: reads a sorted extract file, scores
// each group of rules sharing a source side, and writes the rule table
// (plus an optional count-of-counts file for smoothing).
int ScoreStsg::Main(int argc, char *argv[])
{
  // Process command-line options.
  ProcessOptions(argc, argv, m_options);

  // Open input files.
  Moses::InputFileStream extractStream(m_options.extractFile);
  Moses::InputFileStream lexStream(m_options.lexFile);

  // Open output files.  The count-of-counts file is only produced when a
  // smoothing method that needs it is enabled.
  Moses::OutputFileStream outStream;
  Moses::OutputFileStream countOfCountsStream;
  OpenOutputFileOrDie(m_options.tableFile, outStream);
  if (m_options.goodTuring || m_options.kneserNey) {
    OpenOutputFileOrDie(m_options.tableFile+".coc", countOfCountsStream);
  }

  // Load the lexical probability table unless lexical scoring is disabled.
  if (!m_options.noLex) {
    m_lexTable.Load(lexStream);
  }

  const util::MultiCharacter delimiter("|||");
  std::size_t lineNum = 0;
  std::size_t startLine= 0;
  std::string line;
  std::string tmp;
  RuleGroup ruleGroup;
  RuleTableWriter ruleTableWriter(m_options, outStream);

  // The extract file is assumed sorted by source side, so all rules with the
  // same source arrive as one contiguous group.
  while (std::getline(extractStream, line)) {
    ++lineNum;

    // Tokenize the line:
    //   SOURCE ||| TARGET ||| NT-ALIGN ||| FULL-ALIGN ||| COUNT [||| TREE-SCORE]
    // NOTE(review): a malformed line with too few "|||"-separated fields
    // would advance the iterator past the end — assumes well-formed input.
    util::TokenIter<util::MultiCharacter> it(line, delimiter);
    StringPiece source = *it++;
    StringPiece target = *it++;
    StringPiece ntAlign = *it++;
    StringPiece fullAlign = *it++;
    it->CopyToString(&tmp);
    int count = std::atoi(tmp.c_str());
    double treeScore = 0.0f;
    // The tree score field is only present (and only read) in non-inverse
    // mode when TreeScore/PCFG was requested.
    if (m_options.treeScore && !m_options.inverse) {
      ++it;
      it->CopyToString(&tmp);
      treeScore = std::atof(tmp.c_str());
    }

    // A change of source side marks the start of a new rule group: flush the
    // current group (if any) and begin collecting the next one.
    if (source != ruleGroup.GetSource()) {
      if (lineNum > 1) {
        ProcessRuleGroupOrDie(ruleGroup, ruleTableWriter, startLine, lineNum-1);
      }
      startLine = lineNum;
      ruleGroup.SetNewSource(source);
    }

    // Add the rule to the current group.
    ruleGroup.AddRule(target, ntAlign, fullAlign, count, treeScore);
  }

  // Flush the final rule group.
  ProcessRuleGroupOrDie(ruleGroup, ruleTableWriter, startLine, lineNum);

  // Write the count-of-counts statistics used for Good-Turing / Kneser-Ney.
  if (m_options.goodTuring || m_options.kneserNey) {
    // Total number of distinct rules observed.
    countOfCountsStream << m_totalDistinct << std::endl;
    // Number of rule types seen with (rounded-up) count 1..kCountOfCountsMax.
    // NOTE(review): requires m_countOfCounts to hold kCountOfCountsMax+1
    // elements for index kCountOfCountsMax to be in range — verify the
    // member's construction.
    for (int i = 1; i <= kCountOfCountsMax; ++i) {
      countOfCountsStream << m_countOfCounts[i] << std::endl;
    }
  }

  return 0;
}
00122
00123 void ScoreStsg::TokenizeRuleHalf(const std::string &s, TokenizedRuleHalf &half)
00124 {
00125
00126 std::size_t start = s.find_first_not_of(" \t");
00127 if (start == std::string::npos) {
00128 throw Exception("rule half is empty");
00129 }
00130 std::size_t end = s.find_last_not_of(" \t");
00131 assert(end != std::string::npos);
00132 half.string = s.substr(start, end-start+1);
00133
00134
00135 half.tokens.clear();
00136 for (TreeFragmentTokenizer p(half.string);
00137 p != TreeFragmentTokenizer(); ++p) {
00138 half.tokens.push_back(*p);
00139 }
00140
00141
00142 half.frontierSymbols.clear();
00143 const std::size_t numTokens = half.tokens.size();
00144 for (int i = 0; i < numTokens; ++i) {
00145 if (half.tokens[i].type != TreeFragmentToken_WORD) {
00146 continue;
00147 }
00148 if (i == 0 || half.tokens[i-1].type != TreeFragmentToken_LSB) {
00149
00150 half.frontierSymbols.resize(half.frontierSymbols.size()+1);
00151 half.frontierSymbols.back().value = half.tokens[i].value;
00152 half.frontierSymbols.back().isNonTerminal = false;
00153 } else if (i+1 < numTokens &&
00154 half.tokens[i+1].type == TreeFragmentToken_RSB) {
00155
00156 half.frontierSymbols.resize(half.frontierSymbols.size()+1);
00157 half.frontierSymbols.back().value = half.tokens[i].value;
00158 half.frontierSymbols.back().isNonTerminal = true;
00159 ++i;
00160 }
00161 }
00162 }
00163
00164 void ScoreStsg::ProcessRuleGroupOrDie(const RuleGroup &group,
00165 RuleTableWriter &writer,
00166 std::size_t start,
00167 std::size_t end)
00168 {
00169 try {
00170 ProcessRuleGroup(group, writer);
00171 } catch (const Exception &e) {
00172 std::ostringstream msg;
00173 msg << "failed to process rule group at lines " << start << "-" << end
00174 << ": " << e.msg();
00175 Error(msg.str());
00176 } catch (const std::exception &e) {
00177 std::ostringstream msg;
00178 msg << "failed to process rule group at lines " << start << "-" << end
00179 << ": " << e.what();
00180 Error(msg.str());
00181 }
00182 }
00183
// Scores every distinct rule in a group (rules sharing one source side) and
// writes each surviving rule to the rule table.  Also accumulates the
// count-of-counts histogram when Good-Turing / Kneser-Ney smoothing is on.
void ScoreStsg::ProcessRuleGroup(const RuleGroup &group,
                                 RuleTableWriter &writer)
{
  const std::size_t totalCount = group.GetTotalCount();
  const std::size_t distinctCount = group.GetSize();

  TokenizeRuleHalf(group.GetSource(), m_sourceHalf);

  const bool fullyLexical = m_sourceHalf.IsFullyLexical();

  // Process each distinct rule in turn.
  for (RuleGroup::ConstIterator p = group.Begin(); p != group.End(); ++p) {
    const RuleGroup::DistinctRule &rule = *p;

    // Update count-of-counts statistics.  This happens BEFORE the frequency
    // filter below, so filtered-out rules still contribute to the histogram.
    if (m_options.goodTuring || m_options.kneserNey) {
      ++m_totalDistinct;
      // Round the (possibly fractional) count up to the nearest integer.
      int countInt = rule.count + 0.99999;
      // NOTE(review): requires m_countOfCounts to hold kCountOfCountsMax+1
      // elements for countInt == kCountOfCountsMax to be in range — verify
      // the member's construction.
      if (countInt <= kCountOfCountsMax) {
        ++m_countOfCounts[countInt];
      }
    }

    // Filter out rules with a count lower than the hierarchical-rule
    // threshold, except fully lexical rules, which are always kept.
    if (!fullyLexical && rule.count < m_options.minCountHierarchical) {
      continue;
    }

    TokenizeRuleHalf(rule.target, m_targetHalf);

    // Find the most frequent alignment among those observed for this rule.
    // Assumes rule.alignments is non-empty (every rule was added with an
    // alignment in Main).
    std::vector<std::pair<std::string, int> >::const_iterator q =
      rule.alignments.begin();
    const std::pair<std::string, int> *bestAlignmentAndCount = &(*q++);
    for (; q != rule.alignments.end(); ++q) {
      if (q->second > bestAlignmentAndCount->second) {
        bestAlignmentAndCount = &(*q);
      }
    }
    const std::string &bestAlignment = bestAlignmentAndCount->first;
    ParseAlignmentString(bestAlignment, m_targetHalf.frontierSymbols.size(),
                         m_tgtToSrc);

    // Compute the lexical translation probability for the best alignment.
    double lexProb = ComputeLexProb(m_sourceHalf.frontierSymbols,
                                    m_targetHalf.frontierSymbols, m_tgtToSrc);

    // Write the rule (with all its scores) to the rule table.
    writer.WriteLine(m_sourceHalf, m_targetHalf, bestAlignment, lexProb,
                     rule.treeScore, p->count, totalCount, distinctCount);
  }
}
00237
00238 void ScoreStsg::ParseAlignmentString(const std::string &s, int numTgtWords,
00239 ALIGNMENT &tgtToSrc)
00240 {
00241 tgtToSrc.clear();
00242 tgtToSrc.resize(numTgtWords);
00243
00244 const std::string digits = "0123456789";
00245
00246 std::string::size_type begin = 0;
00247 while (true) {
00248 std::string::size_type end = s.find("-", begin);
00249 if (end == std::string::npos) {
00250 return;
00251 }
00252 int src = std::atoi(s.substr(begin, end-begin).c_str());
00253 if (end+1 == s.size()) {
00254 throw Exception("Target index missing");
00255 }
00256 begin = end+1;
00257 end = s.find_first_not_of(digits, begin+1);
00258 int tgt;
00259 if (end == std::string::npos) {
00260 tgt = std::atoi(s.substr(begin).c_str());
00261 tgtToSrc[tgt].insert(src);
00262 return;
00263 } else {
00264 tgt = std::atoi(s.substr(begin, end-begin).c_str());
00265 tgtToSrc[tgt].insert(src);
00266 }
00267 begin = end+1;
00268 }
00269 }
00270
00271 double ScoreStsg::ComputeLexProb(const std::vector<RuleSymbol> &sourceFrontier,
00272 const std::vector<RuleSymbol> &targetFrontier,
00273 const ALIGNMENT &tgtToSrc)
00274 {
00275 double lexScore = 1.0;
00276 for (std::size_t i = 0; i < targetFrontier.size(); ++i) {
00277 if (targetFrontier[i].isNonTerminal) {
00278 continue;
00279 }
00280 Vocabulary::IdType tgtId = m_tgtVocab.Lookup(targetFrontier[i].value,
00281 StringPieceCompatibleHash(),
00282 StringPieceCompatibleEquals());
00283 const std::set<std::size_t> &srcIndices = tgtToSrc[i];
00284 if (srcIndices.empty()) {
00285
00286 lexScore *= m_lexTable.PermissiveLookup(Vocabulary::NullId(), tgtId);
00287 } else {
00288 double thisWordScore = 0.0;
00289 for (std::set<std::size_t>::const_iterator p = srcIndices.begin();
00290 p != srcIndices.end(); ++p) {
00291 Vocabulary::IdType srcId =
00292 m_srcVocab.Lookup(sourceFrontier[*p].value,
00293 StringPieceCompatibleHash(),
00294 StringPieceCompatibleEquals());
00295 thisWordScore += m_lexTable.PermissiveLookup(srcId, tgtId);
00296 }
00297 lexScore *= thisWordScore / static_cast<double>(srcIndices.size());
00298 }
00299 }
00300 return lexScore;
00301 }
00302
00303 void ScoreStsg::ProcessOptions(int argc, char *argv[], Options &options) const
00304 {
00305 namespace po = boost::program_options;
00306 namespace cls = boost::program_options::command_line_style;
00307
00308
00309
00310 std::ostringstream usageTop;
00311 usageTop << "Usage: " << name()
00312 << " [OPTION]... EXTRACT LEX TABLE\n\n"
00313 << "STSG rule scorer\n\n"
00314 << "Options";
00315
00316
00317 std::ostringstream usageBottom;
00318 usageBottom << "TODO";
00319
00320
00321 po::options_description visible(usageTop.str());
00322 visible.add_options()
00323 ("GoodTuring",
00324 "apply Good-Turing smoothing to relative frequency probability estimates")
00325 ("Hierarchical",
00326 "ignored (included for compatibility with score)")
00327 ("Inverse",
00328 "use inverse mode")
00329 ("KneserNey",
00330 "apply Kneser-Ney smoothing to relative frequency probability estimates")
00331 ("LogProb",
00332 "output log probabilities")
00333 ("MinCountHierarchical",
00334 po::value(&options.minCountHierarchical)->
00335 default_value(options.minCountHierarchical),
00336 "filter out rules with frequency < arg (except fully lexical rules)")
00337 ("NegLogProb",
00338 "output negative log probabilities")
00339 ("NoLex",
00340 "do not compute lexical translation score")
00341 ("NoWordAlignment",
00342 "do not output word alignments")
00343 ("PCFG",
00344 "synonym for TreeScore (included for compatibility with score)")
00345 ("TreeScore",
00346 "include pre-computed tree score from extract")
00347 ("UnpairedExtractFormat",
00348 "ignored (included for compatibility with score)")
00349 ;
00350
00351
00352
00353 po::options_description hidden("Hidden options");
00354 hidden.add_options()
00355 ("ExtractFile",
00356 po::value(&options.extractFile),
00357 "extract file")
00358 ("LexFile",
00359 po::value(&options.lexFile),
00360 "lexical probability file")
00361 ("TableFile",
00362 po::value(&options.tableFile),
00363 "output file")
00364 ;
00365
00366
00367 po::options_description cmdLineOptions;
00368 cmdLineOptions.add(visible).add(hidden);
00369
00370
00371 po::positional_options_description p;
00372 p.add("ExtractFile", 1);
00373 p.add("LexFile", 1);
00374 p.add("TableFile", 1);
00375
00376
00377 po::variables_map vm;
00378 try {
00379 po::store(po::command_line_parser(argc, argv).style(MosesOptionStyle()).
00380 options(cmdLineOptions).positional(p).run(), vm);
00381 po::notify(vm);
00382 } catch (const std::exception &e) {
00383 std::ostringstream msg;
00384 msg << e.what() << "\n\n" << visible << usageBottom.str();
00385 Error(msg.str());
00386 }
00387
00388 if (vm.count("help")) {
00389 std::cout << visible << usageBottom.str() << std::endl;
00390 std::exit(0);
00391 }
00392
00393
00394 if (!vm.count("ExtractFile") ||
00395 !vm.count("LexFile") ||
00396 !vm.count("TableFile")) {
00397 std::ostringstream msg;
00398 std::cerr << visible << usageBottom.str() << std::endl;
00399 std::exit(1);
00400 }
00401
00402
00403 if (vm.count("GoodTuring")) {
00404 options.goodTuring = true;
00405 }
00406 if (vm.count("Inverse")) {
00407 options.inverse = true;
00408 }
00409 if (vm.count("KneserNey")) {
00410 options.kneserNey = true;
00411 }
00412 if (vm.count("LogProb")) {
00413 options.logProb = true;
00414 }
00415 if (vm.count("NegLogProb")) {
00416 options.negLogProb = true;
00417 }
00418 if (vm.count("NoLex")) {
00419 options.noLex = true;
00420 }
00421 if (vm.count("NoWordAlignment")) {
00422 options.noWordAlignment = true;
00423 }
00424 if (vm.count("TreeScore") || vm.count("PCFG")) {
00425 options.treeScore = true;
00426 }
00427 }
00428
00429 }
00430 }
00431 }