00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020 #include "ScfgRuleWriter.h"
00021
00022 #include <cassert>
00023 #include <cmath>
00024 #include <ostream>
00025 #include <map>
00026 #include <sstream>
00027 #include <vector>
00028
00029 #include "Alignment.h"
00030 #include "Options.h"
00031 #include "ScfgRule.h"
00032
00033 namespace MosesTraining
00034 {
00035 namespace Syntax
00036 {
00037 namespace GHKM
00038 {
00039
00040 void ScfgRuleWriter::Write(const ScfgRule &rule, size_t lineNum, bool printEndl)
00041 {
00042 std::ostringstream sourceSS;
00043 std::ostringstream targetSS;
00044
00045 if (m_options.unpairedExtractFormat) {
00046 WriteUnpairedFormat(rule, sourceSS, targetSS);
00047 } else {
00048 WriteStandardFormat(rule, sourceSS, targetSS);
00049 }
00050
00051
00052 if (m_options.t2s) {
00053
00054 m_fwd << targetSS.str() << " ||| " << sourceSS.str() << " |||";
00055 m_inv << sourceSS.str() << " ||| " << targetSS.str() << " |||";
00056 } else {
00057 m_fwd << sourceSS.str() << " ||| " << targetSS.str() << " |||";
00058 m_inv << targetSS.str() << " ||| " << sourceSS.str() << " |||";
00059 }
00060
00061 const Alignment &alignment = rule.GetAlignment();
00062 for (Alignment::const_iterator p = alignment.begin();
00063 p != alignment.end(); ++p) {
00064 if (m_options.t2s) {
00065
00066 m_fwd << " " << p->second << "-" << p->first;
00067 m_inv << " " << p->first << "-" << p->second;
00068 } else {
00069 m_fwd << " " << p->first << "-" << p->second;
00070 m_inv << " " << p->second << "-" << p->first;
00071 }
00072 }
00073
00074 if (m_options.includeSentenceId) {
00075 if (m_options.t2s) {
00076 m_inv << " ||| " << lineNum;
00077 } else {
00078 m_fwd << " ||| " << lineNum;
00079 }
00080 }
00081
00082
00083 m_fwd << " ||| 1";
00084 m_inv << " ||| 1";
00085
00086
00087 if (m_options.pcfg) {
00088 m_fwd << " ||| " << std::exp(rule.GetPcfgScore());
00089 }
00090
00091 m_fwd << " |||";
00092
00093 if (m_options.sourceLabels && rule.HasSourceLabels()) {
00094 m_fwd << " {{SourceLabels";
00095 rule.PrintSourceLabels(m_fwd);
00096 m_fwd << "}}";
00097 }
00098
00099 if (printEndl) {
00100 m_fwd << std::endl;
00101 m_inv << std::endl;
00102 }
00103 }
00104
00105 void ScfgRuleWriter::WriteStandardFormat(const ScfgRule &rule,
00106 std::ostream &sourceSS,
00107 std::ostream &targetSS)
00108 {
00109 const std::vector<Symbol> &sourceRHS = rule.GetSourceRHS();
00110 const std::vector<Symbol> &targetRHS = rule.GetTargetRHS();
00111
00112 std::map<int, int> sourceToTargetNTMap;
00113 std::map<int, int> targetToSourceNTMap;
00114
00115 const Alignment &alignment = rule.GetAlignment();
00116
00117 for (Alignment::const_iterator p(alignment.begin());
00118 p != alignment.end(); ++p) {
00119 if (sourceRHS[p->first].GetType() == NonTerminal) {
00120 assert(targetRHS[p->second].GetType() == NonTerminal);
00121 sourceToTargetNTMap[p->first] = p->second;
00122 targetToSourceNTMap[p->second] = p->first;
00123 }
00124 }
00125
00126
00127 std::vector<std::string> partsOfSpeech;
00128 if (m_options.partsOfSpeechFactor) {
00129 const Subgraph &graphFragment = rule.GetGraphFragment();
00130 graphFragment.GetPartsOfSpeech(partsOfSpeech);
00131 }
00132
00133
00134 int i = 0;
00135 for (std::vector<Symbol>::const_iterator p(sourceRHS.begin());
00136 p != sourceRHS.end(); ++p, ++i) {
00137 WriteSymbol(*p, sourceSS);
00138 if (p->GetType() == NonTerminal) {
00139 int targetIndex = sourceToTargetNTMap[i];
00140 WriteSymbol(targetRHS[targetIndex], sourceSS);
00141 }
00142 sourceSS << " ";
00143 }
00144 if (m_options.conditionOnTargetLhs) {
00145 WriteSymbol(rule.GetTargetLHS(), sourceSS);
00146 } else {
00147 WriteSymbol(rule.GetSourceLHS(), sourceSS);
00148 }
00149
00150
00151 i = 0;
00152 int targetTerminalIndex = 0;
00153 for (std::vector<Symbol>::const_iterator p(targetRHS.begin());
00154 p != targetRHS.end(); ++p, ++i) {
00155 if (p->GetType() == NonTerminal) {
00156 int sourceIndex = targetToSourceNTMap[i];
00157 WriteSymbol(sourceRHS[sourceIndex], targetSS);
00158 }
00159 WriteSymbol(*p, targetSS);
00160
00161 if (m_options.partsOfSpeechFactor && (p->GetType() != NonTerminal)) {
00162 assert(targetTerminalIndex<partsOfSpeech.size());
00163 targetSS << "|" << partsOfSpeech[targetTerminalIndex];
00164 ++targetTerminalIndex;
00165 }
00166 targetSS << " ";
00167 }
00168 WriteSymbol(rule.GetTargetLHS(), targetSS);
00169 }
00170
00171 void ScfgRuleWriter::WriteUnpairedFormat(const ScfgRule &rule,
00172 std::ostream &sourceSS,
00173 std::ostream &targetSS)
00174 {
00175 const std::vector<Symbol> &sourceRHS = rule.GetSourceRHS();
00176 const std::vector<Symbol> &targetRHS = rule.GetTargetRHS();
00177
00178
00179 std::vector<std::string> partsOfSpeech;
00180 if (m_options.partsOfSpeechFactor) {
00181 const Subgraph &graphFragment = rule.GetGraphFragment();
00182 graphFragment.GetPartsOfSpeech(partsOfSpeech);
00183 }
00184
00185
00186 for (std::vector<Symbol>::const_iterator p(sourceRHS.begin());
00187 p != sourceRHS.end(); ++p) {
00188 WriteSymbol(*p, sourceSS);
00189 sourceSS << " ";
00190 }
00191 if (m_options.conditionOnTargetLhs) {
00192 WriteSymbol(rule.GetTargetLHS(), sourceSS);
00193 } else {
00194 WriteSymbol(rule.GetSourceLHS(), sourceSS);
00195 }
00196
00197
00198 int targetTerminalIndex = 0;
00199 for (std::vector<Symbol>::const_iterator p(targetRHS.begin());
00200 p != targetRHS.end(); ++p) {
00201 WriteSymbol(*p, targetSS);
00202
00203 if (m_options.partsOfSpeechFactor && (p->GetType() != NonTerminal)) {
00204 assert(targetTerminalIndex<partsOfSpeech.size());
00205 targetSS << "|" << partsOfSpeech[targetTerminalIndex];
00206 ++targetTerminalIndex;
00207 }
00208 targetSS << " ";
00209 }
00210 WriteSymbol(rule.GetTargetLHS(), targetSS);
00211 }
00212
00213 void ScfgRuleWriter::WriteSymbol(const Symbol &symbol, std::ostream &out)
00214 {
00215 if (symbol.GetType() == NonTerminal) {
00216 out << "[";
00217 if (m_options.stripBitParLabels) {
00218 size_t pos = symbol.GetValue().find('-');
00219 if (pos == std::string::npos) {
00220 out << symbol.GetValue();
00221 } else {
00222 out << symbol.GetValue().substr(0,pos);
00223 }
00224 } else {
00225 out << symbol.GetValue();
00226 }
00227 out << "]";
00228 } else {
00229 out << symbol.GetValue();
00230 }
00231 }
00232
00233 }
00234 }
00235 }