
AmuNMT Code Notes

AmuNMT is an implementation of the inference algorithm for the attention model for neural machine translation. In the code, the attention model is referred to as the "Encoder-Decoder" model. This is because its neural network layers fall into two groups: encoding layers, which turn the input sentence into continuous space representations, and decoding layers, which produce the output words.

From the outside, AmuNMT is a beam search decoder (now using "decoder" as a synonym for "implementation of the inference algorithm for a machine learning model").

The code is in active development, so some of the code snippets referenced below may have changed.

Big Picture

The main function is in common/decoder_main.cpp

  int main(int argc, char* argv[]) {

It maintains a global data structure called God (implemented in common/god.h and common/god.cpp)

  God* god = new God();
  god->Init(argc, argv);

... for global variables, for instance config settings

  size_t cpuThreads = God::Get<size_t>("cpu-threads");

The main function loops through input data

  while(std::getline(God::GetInputStream(), in)) {

It creates a TranslationTask object for a batch of sentences

  TranslationTask(*god, sentences, taskCounter, maxBatchSize);

These tasks are stored in thread pools (there may be multiple threads due to multiple GPUs or multiple CPU cores)

  ThreadPool pool(totalThreads);

... that translate different batches of sentences in parallel.

  pool->enqueue(
           [ = ]{ return TranslationTask(*god, sentences, taskCounter, maxBatchSize); }
   );

In the TranslationTask (implemented in common/translation_task.cpp) sentences are decoded

  std::shared_ptr<Histories> histories = search->Decode(god, *decodeSentences);

... and their translations printed out

  Printer(god, allHistories, strm);

Output is printed via the Printer function in common/printer.h:

  void Printer(const History& history, size_t lineNo, OStream& out) {
    std::string best = Join(God::Postprocess(God::GetTargetVocab()(history.Top().first)));
    LOG(progress) << "Best translation: " << best;

Note: Join(const std::vector<std::string>& words, const std::string del) is a function in common/utils.cpp that joins words with a delimiter, by default a space.
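
A minimal sketch of what such a Join function does, assuming the signature quoted above with a single space as the default delimiter. This is an illustration, not the code in common/utils.cpp; the only deliberate difference is that the delimiter is taken by const reference.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Sketch of a Join helper: concatenates words, inserting `del` between them.
// The default delimiter (a single space) mirrors the behavior described above.
std::string Join(const std::vector<std::string>& words,
                 const std::string& del = " ") {
  std::string out;
  for (size_t i = 0; i < words.size(); ++i) {
    if (i > 0) {
      out += del;  // delimiter only between words, never at the ends
    }
    out += words[i];
  }
  return out;
}
```

For example, Join({"the", "house"}) yields "the house", and an empty word list yields an empty string.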

For plain 1-best output, it then simply prints this translation:

  out << best << std::endl;

The Printer function also deals with n-best list output and output of alignments.

Data Structures to Represent Sentences

A sentence is read into a string in the main decoder loop in common/decoder_main.cpp:

 std::getline(God::GetInputStream(), in)

The Sentence object in common/sentence.h represents a simple sentence as a Words object, i.e., a vector of Word tokens.

Words is defined in common/types.h as a vector of Word objects.

 typedef std::vector<Word> Words;

A Word is an integer token ID (common/types.h), with special IDs for the end of sentence and for unknown words.

 typedef size_t Word;
 const Word EOS = 0;
 const Word UNK = 1;

The conversion of word strings into Word IDs is done with the help of source and target vocabulary Vocab objects (defined in common/vocab.h). Internally, these vocabularies use a map from strings to IDs and a vector from IDs back to strings.

  std::map<std::string, size_t> str2id_;
  std::vector<std::string> id2str_;
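
To make the two-way mapping concrete, here is a toy vocabulary built from the two members above. The class name ToyVocab and the strings "</s>" and "<UNK>" are hypothetical. The reserved IDs 0 (EOS) and 1 (UNK) follow common/types.h as quoted above. Mapping unseen strings to UNK is an assumption about what such a vocabulary should do, not a description of common/vocab.h.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

typedef size_t Word;
typedef std::vector<Word> Words;
const Word EOS = 0;
const Word UNK = 1;

// Toy vocabulary (hypothetical class, not common/vocab.h): str2id_ maps
// strings to IDs, id2str_ maps IDs back to strings.
class ToyVocab {
 public:
  ToyVocab() {
    Insert("</s>");   // gets ID 0 == EOS (string is an assumption)
    Insert("<UNK>");  // gets ID 1 == UNK (string is an assumption)
  }

  // Adds a word if unseen and returns its ID.
  Word Insert(const std::string& word) {
    auto it = str2id_.find(word);
    if (it != str2id_.end()) return it->second;
    Word id = id2str_.size();
    str2id_[word] = id;
    id2str_.push_back(word);
    return id;
  }

  // Strings -> IDs; unknown strings map to UNK.
  Words operator()(const std::vector<std::string>& tokens) const {
    Words ids;
    for (const auto& t : tokens) {
      auto it = str2id_.find(t);
      ids.push_back(it == str2id_.end() ? UNK : it->second);
    }
    return ids;
  }

  // IDs -> strings.
  std::vector<std::string> operator()(const Words& ids) const {
    std::vector<std::string> tokens;
    for (Word id : ids) tokens.push_back(id2str_.at(id));
    return tokens;
  }

 private:
  std::map<std::string, size_t> str2id_;
  std::vector<std::string> id2str_;
};
```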

The relevant lines in common/sentence.cpp that map (simple unfactored) word strings to their vocabulary ID representation are

  Split(tab, lineTokens, " ");
  auto processed = god.Preprocess(i, lineTokens);
  words_.push_back(god.GetSourceVocab(0)(processed));

The Sentence object allows for a factored representation, so that each input word may consist of factors such as the surface word, a part-of-speech tag, other linguistic markup, or a word class. Each factor is a Words object, so a Sentence is actually a vector of such Words objects (common/types.h).

 std::vector<Words> words_;

Input to AmuNMT is tab-separated for the different input factors (which differs from the Moses notation of using the bar | character to separate factors). Hence (common/sentence.cpp):

  Split(line, tabs, "\t");
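
The combination of the two splits (first on tabs into factors, then on spaces into tokens) can be sketched as follows. SplitFactors is a hypothetical helper written with the standard library only, not the Split function from the AmuNMT code.

```cpp
#include <cassert>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical helper (not the AmuNMT code): splits one input line into
// factors. Factors are separated by tabs; within a factor, tokens are
// separated by spaces. Each inner vector is one factor sequence.
std::vector<std::vector<std::string>> SplitFactors(const std::string& line) {
  std::vector<std::vector<std::string>> factors;
  std::stringstream lineStream(line);
  std::string field;
  while (std::getline(lineStream, field, '\t')) {
    std::stringstream fieldStream(field);
    std::vector<std::string> tokens;
    std::string token;
    while (fieldStream >> token) {  // operator>> skips runs of whitespace
      tokens.push_back(token);
    }
    factors.push_back(tokens);
  }
  return factors;
}
```

So the line "the house<TAB>DT NN" yields two factor sequences, {the, house} and {DT, NN}, which line up word by word.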

To maximize GPU usage, several sentences may be translated simultaneously. This is aided by the object Sentences defined in common/sentence.h. It is essentially a vector of pointers to sentences with a shared maximum length.

  class Sentences {
    [...]
    using SentencePtr = std::shared_ptr<Sentence>;
    std::vector<SentencePtr> coll_;
    size_t maxLength_;
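
A stripped-down sketch of how such a container can keep the shared maximum length up to date as sentences are added. ToySentence and ToySentences are hypothetical stand-ins with only the members needed here. The maximum length matters later, for instance as the bound 3 * GetMaxLength() on the decoding loop.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

typedef size_t Word;
typedef std::vector<Word> Words;

// Toy stand-ins (hypothetical, not common/sentence.h): a sentence is just its
// word IDs, and the batch tracks the maximum length as sentences are added.
struct ToySentence {
  Words words;
};

class ToySentences {
 public:
  using SentencePtr = std::shared_ptr<ToySentence>;

  void push_back(SentencePtr sentence) {
    maxLength_ = std::max(maxLength_, sentence->words.size());
    coll_.push_back(sentence);
  }

  size_t size() const { return coll_.size(); }
  size_t GetMaxLength() const { return maxLength_; }
  const SentencePtr& at(size_t i) const { return coll_.at(i); }

 private:
  std::vector<SentencePtr> coll_;
  size_t maxLength_ = 0;
};
```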

Data Structures to Represent Search States

One view of decoding is that it builds sequences of words. Each produced word results in a hypothesis, which is connected to previous hypotheses.

A Hypothesis (defined in common/hypothesis.h) consists of the produced word, its cost (and cost breakdown), and a link to the previous hypothesis:

  const HypothesisPtr prevHyp_;
  const size_t prevIndex_;
  const size_t word_;
  const float cost_;
  std::vector<SoftAlignmentPtr> alignments_;
  std::vector<float> costBreakdown_;

Note:

  typedef std::shared_ptr<Hypothesis> HypothesisPtr;
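
Because each Hypothesis points back to its predecessor, the output words of a finished hypothesis can be recovered by walking the back-pointers. The simplified class below omits alignments and the cost breakdown. Its constructor argument order follows the call new Hypothesis(prevHyps[hypIndex], wordIndex, hypIndex, cost) used in the beam code. The accessor GetPrevStateIndex and the helper Backtrack are hypothetical names. Keeping the final EOS in the output is a choice of this sketch.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

typedef size_t Word;
typedef std::vector<Word> Words;
const Word EOS = 0;

class Hypothesis;
typedef std::shared_ptr<Hypothesis> HypothesisPtr;

// Simplified Hypothesis (hypothetical; no alignments or cost breakdown).
class Hypothesis {
 public:
  // Initial hypothesis: no predecessor, EOS as word, zero cost.
  Hypothesis() : prevHyp_(nullptr), prevIndex_(0), word_(EOS), cost_(0.0f) {}

  Hypothesis(const HypothesisPtr& prevHyp, size_t word, size_t prevIndex,
             float cost)
      : prevHyp_(prevHyp), prevIndex_(prevIndex), word_(word), cost_(cost) {}

  const HypothesisPtr& GetPrevHyp() const { return prevHyp_; }
  size_t GetPrevStateIndex() const { return prevIndex_; }
  size_t GetWord() const { return word_; }
  float GetCost() const { return cost_; }

 private:
  const HypothesisPtr prevHyp_;
  const size_t prevIndex_;
  const size_t word_;
  const float cost_;
};

// Follows the back-pointers from a final hypothesis to the initial one and
// returns the produced words left to right. The initial hypothesis (the one
// without a predecessor) contributes no word.
Words Backtrack(HypothesisPtr hyp) {
  Words words;
  while (hyp && hyp->GetPrevHyp()) {
    words.push_back(hyp->GetWord());
    hyp = hyp->GetPrevHyp();
  }
  std::reverse(words.begin(), words.end());
  return words;
}
```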

A Beam is a vector of Hypothesis objects that have been created in the most recent processing step of the decoding loop.

 typedef std::vector<HypothesisPtr> Beam;

The search graph is recorded in the History (defined in common/history.h) which is the sequence of Beam objects. It also maintains a priority queue of the best scoring completed sentence translations.

 std::vector<Beam> history_;
 std::priority_queue<HypothesisCoord> topHyps_;
 bool normalize_;
 size_t lineNo_;

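A helper object is HypothesisCoord, which points to a Hypothesis in the History: i is the index of the Beam (the decoding step) and j the position of the Hypothesis within that Beam.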

  struct HypothesisCoord {
    size_t i;
    size_t j;
    float cost;
  };

The function that produces n-best lists, NBestList NBest(size_t n), shows how these coordinates of Hypothesis objects in the History matrix are used.

  size_t start = bestHypCoord.i;
  size_t j  = bestHypCoord.j;
  HypothesisPtr bestHyp = history_[start][j];

The Add function in the History object shows how a new Beam of Hypothesis objects is processed to update the priority queue of top-scoring hypotheses that are complete sentence translations. A hypothesis is complete if it produced the end-of-sentence token EOS or if the maximum output length has been reached (last).

  for (size_t j = 0; j < beam.size(); ++j)
    if(beam[j]->GetWord() == EOS || last) {
      float cost = normalize_ ? beam[j]->GetCost() / history_.size() : beam[j]->GetCost();
      topHyps_.push({ history_.size(), j, cost });
    }
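
A runnable sketch of this bookkeeping, with a minimal hypothesis type (a word and its accumulated cost). It assumes that costs are log-probabilities, so higher is better and the max-heap std::priority_queue puts the best finished hypothesis on top. ToyHistory, Hyp, and the guard that skips the initial beam (whose hypotheses carry EOS only as a start marker, and for which the normalizer would be 0) are assumptions of this sketch, not taken from common/history.h.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <queue>
#include <vector>

typedef size_t Word;
const Word EOS = 0;

// Minimal stand-in (hypothetical): a word and its accumulated log-prob cost.
struct Hyp {
  Word word;
  float cost;
};
typedef std::shared_ptr<Hyp> HypPtr;
typedef std::vector<HypPtr> Beam;

// i: index of the Beam in the history, j: position within that Beam.
struct HypothesisCoord {
  size_t i;
  size_t j;
  float cost;
  // std::priority_queue is a max-heap, so top() has the highest cost.
  bool operator<(const HypothesisCoord& other) const {
    return cost < other.cost;
  }
};

class ToyHistory {
 public:
  explicit ToyHistory(bool normalize) : normalize_(normalize) {}

  // Records finished hypotheses (EOS, or any hypothesis if last), then
  // appends the beam. The initial beam is skipped (guard of this sketch).
  void Add(const Beam& beam, bool last = false) {
    if (!history_.empty()) {
      for (size_t j = 0; j < beam.size(); ++j) {
        if (beam[j]->word == EOS || last) {
          // history_.size() == number of words produced so far
          float cost = normalize_ ? beam[j]->cost / history_.size()
                                  : beam[j]->cost;
          topHyps_.push({history_.size(), j, cost});
        }
      }
    }
    history_.push_back(beam);
  }

  HypothesisCoord Top() const { return topHyps_.top(); }
  size_t size() const { return history_.size(); }

 private:
  std::vector<Beam> history_;
  std::priority_queue<HypothesisCoord> topHyps_;
  bool normalize_;
};
```

With normalization switched on, a longer translation with cost -2.0 (length 2, normalized -1.0) beats a shorter one with cost -1.5 (length 1, normalized -1.5), which it would not do on raw costs.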

Note that the hypothesis costs may be normalized by the length of the partial translation (if normalize_ is set). I am not sure whether this is a good thing.

I take the liberty of adding a mental note about the length- and coverage-normalized scoring of hypotheses proposed by Google (see Section 7 of Wu et al. (2016)), which perhaps should be implemented somewhere. Google's normalization of hypothesis scores is used during search, not just applied to final hypotheses, so it would be implemented elsewhere in the code.
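
For reference, the score in Section 7 of Wu et al. (2016) is s(Y,X) = log P(Y|X) / lp(Y) + cp(X;Y). Here lp(Y) = (5 + |Y|)^α / (5 + 1)^α and cp(X;Y) = β Σ_i log(min(Σ_j p_ij, 1.0)), where p_ij is the attention probability of the j-th target word on the i-th source word. A minimal sketch of that formula follows. It is not part of AmuNMT; the function names are hypothetical, and α and β are the tuning parameters from the paper.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Length penalty lp(Y) = (5 + |Y|)^alpha / (5 + 1)^alpha.
double LengthPenalty(size_t targetLength, double alpha) {
  return std::pow(5.0 + targetLength, alpha) / std::pow(6.0, alpha);
}

// Coverage penalty cp(X;Y) = beta * sum_i log(min(sum_j p_ij, 1.0)).
// attention[j][i] is the weight of target word j on source word i, so the
// inner sum runs over target positions. A source word with zero total
// attention gives log(0) = -infinity, exactly as the formula says.
double CoveragePenalty(const std::vector<std::vector<double>>& attention,
                       size_t sourceLength, double beta) {
  double sum = 0.0;
  for (size_t i = 0; i < sourceLength; ++i) {
    double coverage = 0.0;
    for (const auto& row : attention) {
      coverage += row[i];
    }
    sum += std::log(std::min(coverage, 1.0));
  }
  return beta * sum;
}

// s(Y, X) = log P(Y|X) / lp(Y) + cp(X; Y)
double GnmtScore(double logProb, size_t targetLength,
                 const std::vector<std::vector<double>>& attention,
                 size_t sourceLength, double alpha, double beta) {
  return logProb / LengthPenalty(targetLength, alpha) +
         CoveragePenalty(attention, sourceLength, beta);
}
```

With alpha = 0 the length penalty is 1, and when every source word receives a total attention of at least 1 the coverage penalty is 0, so the score reduces to log P(Y|X).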

Histories is a class that contains a vector of History objects with some additional support functions (common/history.h).

A Result is a sequence of Words with a pointer to the final hypothesis of the search that produced it (defined in common/hypothesis.h).

  typedef std::pair<Words, HypothesisPtr> Result;

An NBestList is a list of Result objects.

  typedef std::vector<Result> NBestList;

Decoding

Sentences to be translated are populated in common/decoder_main.cpp

  while (std::getline(god->GetInputStream(), in)) {
    sentences->push_back(SentencePtr(new Sentence(god, lineNum++, in)));

... and translation tasks are created for a batch of sentences in sequence ...

    if (sentences->size() >= maxiBatch) {
      pool->enqueue(
          [=]{ return TranslationTask(god, sentences, taskCounter); }
      );

There are two kinds of batch size: maxiBatch is the total number of sentences in a TranslationTask. Within a translation task, these may be broken down into smaller batches of size miniBatch.

A TranslationTask in common/translation_task.cpp sorts sentences by length

  sentences->SortByLength();

... selects a batch of sentences (up to miniBatch) ...

  miniBatch = god.Get<size_t>("mini-batch");

... to be translated together ...

  std::shared_ptr<Sentences> decodeSentences(new Sentences(taskCounter, bunchId++));
  for (size_t i = 0; i < sentences->size(); ++i) {
    decodeSentences->push_back(sentences->at(i));

... and then decodes each batch (miniBatch) of sentences

    if (decodeSentences->size() >= miniBatch) {
      std::shared_ptr<Histories> histories = search->Process(god, *decodeSentences);
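
The sort-then-chunk pattern can be sketched on sentence lengths alone. MakeMiniBatches is a hypothetical helper. Sorting in ascending order and flushing a final, smaller batch for the leftover sentences are assumptions of this sketch. Grouping sentences of similar length means less padding when a batch is turned into matrices.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Hypothetical helper: sorts sentence indices by length and groups them into
// mini-batches of at most miniBatch (> 0) sentences. A final, smaller batch
// holds any leftover sentences.
std::vector<std::vector<size_t>> MakeMiniBatches(
    const std::vector<size_t>& lengths, size_t miniBatch) {
  std::vector<size_t> order(lengths.size());
  std::iota(order.begin(), order.end(), 0);  // 0, 1, 2, ...
  std::stable_sort(order.begin(), order.end(), [&](size_t a, size_t b) {
    return lengths[a] < lengths[b];
  });

  std::vector<std::vector<size_t>> batches;
  std::vector<size_t> current;
  for (size_t idx : order) {
    current.push_back(idx);
    if (current.size() >= miniBatch) {  // same test as in the loop above
      batches.push_back(current);
      current.clear();
    }
  }
  if (!current.empty()) {
    batches.push_back(current);  // leftover sentences
  }
  return batches;
}
```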

Initialize Main Data Structures

The wrapper around the search is Search::Process in common/search.cpp. It also handles pre-processing and post-processing. In addition, it calls encoding of the input sentence and decoding into the output sentence as separate functions.

We start by initializing a History object for each of the sentences and storing them in the aggregate Histories object histories. The histories record the search graph (a graph of Hypothesis objects) and will be returned as the main result of the search. This result can then be mined for 1-best or n-best translations, or other details.

  std::shared_ptr<Histories> histories(new Histories(god, sentences));

For each sentence, an initial Hypothesis is created. This initial Hypothesis has no cost, no back-pointer to a previous Hypothesis and the end of sentence marker EOS as first word.

  Beam prevHyps(batchSize, HypothesisPtr(new Hypothesis()));

The states of each scorer, for the current and for the next step, are set up as well.

  States states = NewStates();
  States nextStates = NewStates();

Search::NewStates sets up data structures for the scorer states.

  States states(numScorers);

It then loops through the scorers and asks each Scorer i for a new state:

  for (size_t i = 0; i < numScorers; i++) {
    Scorer &scorer = *scorers_[i];
    states[i].reset(scorer.NewState());
  }

The 4 processing steps are:

  PreProcess(god, sentences, histories, prevHyps);
  Encode(sentences, states);
  Decode(god, sentences, states, nextStates, histories, prevHyps);
  PostProcess();

Preprocessing: Setting up Beam

Search::PreProcess initializes the hypotheses in the beam. I am confused about the following. Each History gets the whole vector of initial Hypothesis objects (one for each sentence), even though there is a separate History for each sentence. I strongly suspect that this is unintended, although it is not a bug: each History should only get one hypothesis, but since beamSizes is set to 1 for each sentence, the others are ignored.

  for (size_t i = 0; i < histories->size(); ++i) {
    History &history = *histories->at(i).get();
    history.Add(prevHyps);
  }

Encode: Initialize Search, Compute Input Representations

Search::Encode sets up the scorers for each sentence.

  for (size_t i = 0; i < scorers_.size(); i++) {
    Scorer &scorer = *scorers_[i];
    scorer.SetSource(sentences);
    scorer.BeginSentenceState(*states[i], sentences.size());
  }

Afterwards, states contains begin-of-sentence states, while nextStates is still just an empty object.

SetSource actually does a lot of the heavy lifting for the encoder-decoder model: it takes the input sentences, looks up their embeddings, and runs the bidirectional recurrent neural network over them.

Decode: Main Decoding Loop

The main decoding function is Search::Decode in common/search.cpp. This implementation takes a very abstract view of how hypotheses are created and scored. It knows that there are a number of Scorer objects that produce next states in the search and assign costs to them, but how this is done is left to the implementation of each Scorer. In fact, a Scorer need not be the attentional encoder-decoder model that we mainly care about; it may also be a neural network language model or even a traditional n-gram language model. Even when sticking to the attentional encoder-decoder model, there may be multiple such models in ensemble decoding.

More on the implementation of Scorer later; let us first walk through the Search::Decode function. Recall that it may translate multiple sentences at once:

  void Search::Decode(
      const God &god,
      const Sentences& sentences,
      const States &states,
      States &nextStates,
      std::shared_ptr<Histories> &histories,
      Beam &prevHyps)
  {

beamSizes keeps track of how many Hypothesis objects exist for each sentence. This will be handy to map each Hypothesis in the current list of active hypotheses (see prevHyps directly below) to its sentence. Initially, each sentence has 1 Hypothesis, so beamSizes is initialized to a vector of 1s of length batchSize.

  size_t batchSize = sentences.size();
  std::vector<size_t> beamSizes(batchSize, 1);

At this point, everything is set up to carry out search. We progress through search by generating one output word at a time. This continues until the output reaches three times the maximum input sentence length, or until no survivors remain in the beam (i.e., all active Hypothesis sequences have ended with an end-of-sentence token).

This basic loop is set up as follows; we will look inside it next.

  for (size_t decoderStep = 0; decoderStep < 3 * sentences.GetMaxLength(); ++decoderStep) {
    [...]
    if (survivors.size() == 0) {
      break;
    }
    prevHyps.swap(survivors);
  }

Note that survivors is a single list of Hypothesis objects, which may relate to different sentences. The bookkeeping to map each Hypothesis to a sentence is beamSizes (hyp to sentence, see above) and beams (sentence to hyp, see below).

Inside the loop, each scorer is instructed to score its state.

    for (size_t i = 0; i < scorers_.size(); i++) {
      Scorer &scorer = *scorers_[i];
      State &state = *states[i];
      State &nextState = *nextStates[i];
      scorer.Score(god, state, nextState, beamSizes);
    }

This triggers the main neural network computation to generate a new word. state (a list of states, one for each hypothesis) is expanded to nextState. The vector beamSizes helps with mapping each state to an input sentence.

The number of hypotheses to be generated by the beam search for each sentence is specified in beamSizes. After the first scoring step, every entry is set to the user-specified beam size:

    if (decoderStep == 0) {
      for (auto& beamSize : beamSizes) {
        beamSize = god.Get<size_t>("beam-size");
      }
    }

The following call does the softmax, finds the best Hypothesis objects for each sentence, and places them into the newly defined list (of lists) beams.

   Beams beams(batchSize);
   bestHyps_->CalcBeam(god, prevHyps, scorers_, filterIndices_, returnAlignment, beams, beamSizes);

beams maps each sentence to its hypotheses: for each sentence i, beams[i] contains a list of its Hypothesis objects.

Going Deeper into the Code: Beam Progression

Let us look under the hood a bit, at the GPU implementation of CalcBeam in gpu/decoder/best_hyps.h.

First, all the scores from the different scorers and the original hypothesis scores are combined.

The state progression call (scorer.Score, above) produced a probability distribution over words for each hypothesis. We now look it up:

  mblas::Matrix& Probs = static_cast<mblas::Matrix&>(scorers[0]->GetProbs());

To these, we add the original hypothesis scores. First, we pull these scores out into a vector vCosts and copy it into Costs:

  HostVector<float> vCosts;
  for (auto& h : prevHyps) {
    vCosts.push_back(h->GetCost());
  }
  mblas::copy(vCosts.begin(), vCosts.end(), Costs.begin());

The following function adds this vector of scores (one for each hypothesis) to the matrix of scores (one dimension for the hypothesis, another dimension for the predicted words):

  BroadcastVecColumn(weights_.at(scorers[0]->GetName()) * _1 + _2, Probs, Costs);

The scores from all the other scorers are also added to the combined Probs:

  for (size_t i = 1; i < scorers.size(); ++i) {
    mblas::Matrix &currProbs = static_cast<mblas::Matrix&>(scorers[i]->GetProbs());
    Element(_1 + weights_.at(scorers[i]->GetName()) * _2, Probs, currProbs);
  }
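
On the CPU, the two steps above amount to the following elementwise computation. The sketch assumes the scorers return log-probabilities, which is what adding the accumulated costs implies. It uses plain nested vectors, and CombineScores is a hypothetical name, not an mblas function. Rows are hypotheses, columns are words.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

typedef std::vector<std::vector<float>> Matrix;  // rows: hypotheses, cols: words

// CPU sketch (hypothetical, not the mblas kernels) of the score combination:
// probs[k] is scorer k's matrix, weights[k] its weight, costs[h] the
// accumulated cost of hypothesis h.
Matrix CombineScores(const std::vector<Matrix>& probs,
                     const std::vector<float>& weights,
                     const std::vector<float>& costs) {
  Matrix total = probs[0];
  for (size_t h = 0; h < total.size(); ++h) {
    for (size_t w = 0; w < total[h].size(); ++w) {
      // BroadcastVecColumn step: weight_0 * score + cost of hypothesis h
      total[h][w] = weights[0] * total[h][w] + costs[h];
      // Element step for the remaining scorers: add weight_k * score_k
      for (size_t k = 1; k < probs.size(); ++k) {
        total[h][w] += weights[k] * probs[k][h][w];
      }
    }
  }
  return total;
}
```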

This now sets up the optimization problem of finding the top n (hypothesis, word) predictions for each sentence, across all of its hypotheses. This is ultimately done with a call to a custom nthElement implementation (called from the helper function FindBests), which returns the best costs (bestCosts) and their keys (bestKeys):

  const bool isFirst = (vCosts[0] == 0.0f) ? true : false;
  std::vector<float> bestCosts;
  std::vector<unsigned> bestKeys;
  FindBests(beamSizes, Probs, bestCosts, bestKeys, isFirst);

The resulting best hypotheses are grouped per sentence, as specified by the beamSizes data structure.

bestKeys contains keys that encode the pair (hypothesis, predicted word).

  size_t wordIndex = bestKeys[i] % Probs.Cols();
  size_t hypIndex  = bestKeys[i] / Probs.Cols();
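
A CPU sketch of the selection and of the key encoding (key = hypIndex * Probs.Cols() + wordIndex) uses std::partial_sort instead of the custom GPU nthElement. FindBestsCpu is a hypothetical name. The number of rows each sentence currently has (rowsPerSentence) is passed separately from beamSizes. This is an assumption that sidesteps the isFirst special case, where each sentence still has a single row but beamSizes already holds the full beam size. Rows of one sentence are assumed to be contiguous.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// CPU sketch (hypothetical). scores is the flattened Probs matrix
// (row = hypothesis, col = word). Sentence s owns rowsPerSentence[s]
// contiguous rows and keeps its beamSizes[s] best (hypothesis, word) pairs.
void FindBestsCpu(const std::vector<float>& scores, size_t cols,
                  const std::vector<size_t>& rowsPerSentence,
                  const std::vector<size_t>& beamSizes,
                  std::vector<float>& bestCosts,
                  std::vector<unsigned>& bestKeys) {
  size_t firstRow = 0;
  for (size_t s = 0; s < beamSizes.size(); ++s) {
    // Candidate keys for sentence s: key = row * cols + word.
    std::vector<unsigned> keys(rowsPerSentence[s] * cols);
    std::iota(keys.begin(), keys.end(), static_cast<unsigned>(firstRow * cols));
    size_t n = std::min(beamSizes[s], keys.size());
    // Highest scores first.
    std::partial_sort(keys.begin(), keys.begin() + n, keys.end(),
                      [&](unsigned a, unsigned b) { return scores[a] > scores[b]; });
    for (size_t k = 0; k < n; ++k) {
      bestKeys.push_back(keys[k]);
      bestCosts.push_back(scores[keys[k]]);
    }
    firstRow += rowsPerSentence[s];
  }
}
```

Each returned key decodes back with key % cols (word) and key / cols (hypothesis), as in the two lines above.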

These two values are then used to create a new Hypothesis object and add it to the new list of active hypotheses.

  HypothesisPtr hyp;
  hyp.reset(new Hypothesis(prevHyps[hypIndex], wordIndex, hypIndex, cost));
  beams[batchMap[i]].push_back(hyp);

Note that batchMap is the reverse lookup table for beamSizes: for each position i in the list of best keys, batchMap[i] is the sentence it belongs to. It is constructed in this function for the use above.
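
One straightforward way to build such a reverse lookup table from beamSizes is to repeat each sentence index as often as that sentence has entries. MakeBatchMap is a hypothetical helper, not the code in best_hyps.h. A finished sentence with a beam size of 0 simply contributes no entries.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sketch: expands beamSizes (entries per sentence) into a table
// that gives, for each position in the flat list, the sentence it belongs to.
std::vector<size_t> MakeBatchMap(const std::vector<size_t>& beamSizes) {
  std::vector<size_t> batchMap;
  for (size_t sentence = 0; sentence < beamSizes.size(); ++sentence) {
    for (size_t k = 0; k < beamSizes[sentence]; ++k) {
      batchMap.push_back(sentence);
    }
  }
  return batchMap;
}
```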

Decode: Main Decoding Loop, Continued

We now record the state of the search in the History objects for each sentence. For each sentence, a new beam (i.e., list of Hypothesis objects) is added.

  for (size_t i = 0; i < batchSize; ++i) {
      if (!beams[i].empty()) {
        histories->at(i)->Add(beams[i], histories->at(i)->size() == 3 * sentences.at(i)->GetWords().size());
      }
    }

This is generally straightforward. There are two special cases to deal with: (a) the beam may be empty, so nothing gets added to the History (hence the if statement), and (b) the maximum output length may be reached, so the History is informed that this is the last Hypothesis no matter what (hence the second argument to the Add function).

The following loops through each sentence (batchID) and adds its Hypothesis objects h to the global survivor list. However, some new Hypothesis objects may have produced the end-of-sentence token EOS; these are not added to the list, and the beamSizes entry of the corresponding sentence is reduced.

    Beam survivors;
    for (size_t batchID = 0; batchID < batchSize; ++batchID) {
      for (auto& h : beams[batchID]) {
        if (h->GetWord() != EOS) {
          survivors.push_back(h);
        } else {
          --beamSizes[batchID];
        }
      }
    } 

We are almost ready for the next iteration of the loop. What remains, roughly: each scorer has to copy nextStates[i] to states[i], keeping only the states that belong to Hypothesis objects that are still around (survivors).

    for (size_t i = 0; i < scorers_.size(); i++) {
      scorers_[i]->AssembleBeamState(*nextStates[i], survivors, *states[i]);
    }
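
As far as I can tell, the core of this step is a gather: each survivor extends the hypothesis at index prevIndex (the third constructor argument of Hypothesis, see hypIndex above), so its state is that row of nextState. The sketch below shows only this gather on plain vectors. GatherStates is a hypothetical name. The assumption that nextState holds one row per previous hypothesis is mine, not taken from the scorer code, and the real AssembleBeamState may do more (e.g., look up embeddings of the new words).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

typedef std::vector<std::vector<float>> StateMatrix;  // one row per hypothesis

// Hypothetical sketch: survivor k extends the hypothesis with index
// prevIndices[k], so its state is row prevIndices[k] of nextState.
// Several survivors may share (and therefore copy) the same row.
StateMatrix GatherStates(const StateMatrix& nextState,
                         const std::vector<size_t>& prevIndices) {
  StateMatrix state;
  state.reserve(prevIndices.size());
  for (size_t prev : prevIndices) {
    state.push_back(nextState.at(prev));
  }
  return state;
}
```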

Data Structures to Represent States in the Neural Network

  typedef std::vector<StatePtr> States;

State

Scorer is a base class with virtual member functions. We mainly care about its encoder-decoder implementation:

  EncoderDecoder::EncoderDecoder(const std::string& name,
               const YAML::Node& config,
               size_t tab,
               const Weights& model)

The encoder is implemented in gpu/dl4mt/encoder.cu.

Encoder::GetContext(const Sentences& source, size_t tab, mblas::Matrix& Context, DeviceVector<int>& dMapping) runs the encoder over the input sentences and hence computes the input representation in the GPU code. Conceptually, this is a three-dimensional object (multiple sentences * multiple words per sentence * dimensions of the embedding), but it is actually implemented as a two-dimensional object, i.e., a matrix, by concatenating all the sentences together.
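
Under the concatenation layout just described, finding the matrix row of word t of sentence s only needs the start row of each sentence, an exclusive prefix sum over the sentence lengths. RowOffsets and ContextRow are hypothetical helpers that illustrate the indexing; they are not taken from encoder.cu.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sketch: with all sentences concatenated into one matrix,
// sentence s starts at row offsets[s] (exclusive prefix sum of lengths).
std::vector<size_t> RowOffsets(const std::vector<size_t>& lengths) {
  std::vector<size_t> offsets;
  size_t next = 0;
  for (size_t len : lengths) {
    offsets.push_back(next);
    next += len;
  }
  return offsets;
}

// Row of word t of sentence s in the two-dimensional context matrix.
size_t ContextRow(const std::vector<size_t>& offsets, size_t s, size_t t) {
  return offsets[s] + t;
}
```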

The decoder advances one step at a time in

  EncoderDecoder::Score / gpu/decoder/encoder_decoder.cu

The actual matrix-level code is in

  MakeStep / gpu/dl4mt/decoder.h
Page last modified on February 23, 2017, at 03:11 PM