
Marian Code Notes

Marian is a plain C++ implementation of neural machine translation models that uses no underlying toolkit (TensorFlow, Theano, etc.). Like DyNet and other flexible neural network toolkits, it allows you to define a computation graph. The toolkit then uses this graph to decode and to compute derivatives for back-propagation training.

For instance, a basic feed-forward neural network can be mathematically defined as

y = sigmoid ( W_2 sigmoid( W_1 x + b_1 ) + b_2 )
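To make the formula concrete, here is a minimal CPU sketch of this two-layer network in plain C++. It is not Marian code: the Vector/Matrix aliases and the function names are invented for illustration, and matrices are simply stored as lists of rows.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Illustrative types only (not Marian's): a Matrix is a list of rows.
using Vector = std::vector<float>;
using Matrix = std::vector<Vector>;

float sigmoid(float v) { return 1.0f / (1.0f + std::exp(-v)); }

// Computes sigmoid(W x + b) element-wise; W has one row per output value.
Vector sigmoidLayer(const Matrix& W, const Vector& x, const Vector& b) {
  assert(W.size() == b.size());
  Vector out(W.size());
  for (std::size_t i = 0; i < W.size(); ++i) {
    assert(W[i].size() == x.size());
    float sum = b[i];
    for (std::size_t j = 0; j < x.size(); ++j) sum += W[i][j] * x[j];
    out[i] = sigmoid(sum);
  }
  return out;
}

// y = sigmoid(W_2 sigmoid(W_1 x + b_1) + b_2)
Vector feedForward(const Matrix& W1, const Vector& b1,
                   const Matrix& W2, const Vector& b2, const Vector& x) {
  return sigmoidLayer(W2, sigmoidLayer(W1, x, b1), b2);
}
```

The nesting of the two sigmoidLayer calls mirrors the nesting of the formula; training would adjust W1, b1, W2, b2, which this sketch only reads.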

But a "neural network" can be pretty much any mathematical formula that takes input values x and ultimately maps them to output values y.

The task of training is to find the parameter values W_1, W_2, b_1, b_2. This requires computing the derivatives of the cost function with respect to the parameter weights.
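As a toy illustration of what "the derivative of the cost function with respect to the parameters" means, the sketch below uses a single sigmoid neuron with a squared-error cost. Both choices are ours for illustration, not something prescribed by Marian. It derives dC/dw with the chain rule and checks it against a finite-difference estimate; all function names are invented.

```cpp
#include <cassert>
#include <cmath>

// Toy model (illustrative only): one neuron y = sigmoid(w * x + b) with
// squared-error cost C = 0.5 * (y - t)^2 for one training example (x, t).
double sigmoid(double v) { return 1.0 / (1.0 + std::exp(-v)); }

double cost(double w, double b, double x, double t) {
  double y = sigmoid(w * x + b);
  return 0.5 * (y - t) * (y - t);
}

// Analytic dC/dw via the chain rule:
// dC/dy = (y - t), dy/dz = y * (1 - y), dz/dw = x, where z = w * x + b.
double gradW(double w, double b, double x, double t) {
  double y = sigmoid(w * x + b);
  return (y - t) * y * (1.0 - y) * x;
}

// Central finite difference, used only to check the analytic gradient.
double numericGradW(double w, double b, double x, double t, double eps = 1e-6) {
  assert(eps > 0.0);
  return (cost(w + eps, b, x, t) - cost(w - eps, b, x, t)) / (2.0 * eps);
}

// One gradient-descent step on w with learning rate lr.
double sgdStep(double w, double b, double x, double t, double lr) {
  return w - lr * gradW(w, b, x, t);
}
```

A toolkit like Marian automates exactly this chain-rule bookkeeping across the whole computation graph, so the derivative of every parameter does not have to be written out by hand.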

Definition of the computation graph

The core element of the computation graph is a basic computation (addition, matrix multiplication, an activation function such as sigmoid, etc.) that maps between tensors (i.e., scalars, vectors, matrices, ...). All the relevant code to define a computation graph is in the directory graph.

Operations in the computation graph

This basic element is called Chainable, defined in graph/chainable.h, although it is almost always referred to through its pointer type Expr:

  typedef Ptr<Chainable<Tensor>> Expr;

Typically, each such element has a forward computation (the calculation it carries out on its inputs), a backward operation (the derivative for weight updates), and many more features.

The basic root elements of computation graphs, such as input vectors, weight matrices, and bias vectors, are also Chainable objects.
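The following is a heavily simplified, hypothetical sketch of this design, not Marian's actual class: a Chainable interface with forward and backward methods, an Expr pointer type (assuming, as the typedef suggests, that Ptr behaves like std::shared_ptr), a root element (Leaf), and a sigmoid operation (SigmoidNode). Here a "tensor" is just a single float.

```cpp
#include <cassert>
#include <cmath>
#include <memory>
#include <utility>

// Assumption for this sketch: Ptr is an alias for std::shared_ptr.
template <class T>
using Ptr = std::shared_ptr<T>;

// Simplified, hypothetical interface; a "tensor" here is a single float.
template <class Tensor>
struct Chainable {
  virtual ~Chainable() = default;
  virtual void forward() = 0;      // compute this node's value
  virtual void backward() = 0;     // add this node's adjoint into the children's gradients
  virtual Tensor val() const = 0;  // value computed by forward()
  virtual Tensor& grad() = 0;      // accumulated gradient (adjoint)
};

typedef Ptr<Chainable<float>> Expr;

// Root element (input, weight, bias): holds a value, has no children.
struct Leaf : Chainable<float> {
  float value;
  float adjoint = 0.0f;
  explicit Leaf(float v) : value(v) {}
  void forward() override {}
  void backward() override {}
  float val() const override { return value; }
  float& grad() override { return adjoint; }
};

// Unary operation node: value = sigmoid(child value).
struct SigmoidNode : Chainable<float> {
  Expr child;
  float value = 0.0f;
  float adjoint = 0.0f;
  explicit SigmoidNode(Expr c) : child(std::move(c)) { assert(child != nullptr); }
  void forward() override { value = 1.0f / (1.0f + std::exp(-child->val())); }
  void backward() override { child->grad() += adjoint * value * (1.0f - value); }
  float val() const override { return value; }
  float& grad() override { return adjoint; }
};
```

Because both kinds of node share one interface, code that walks the graph only ever handles Expr pointers and never needs to know whether a node is a parameter or an operation.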

There are a few more specific types of Chainable.

https://github.com/amunmt/marian/blob/master/src/graph/node.h

 class Node : public Chainable<Tensor>, [...]

Nodes may be the results of operations with any number of operands (any arity).

https://github.com/amunmt/marian/blob/master/src/graph/node.h

 struct NaryNodeOp : public Node

Special cases are unary operators, such as sigmoid or scalar multiplication.

https://github.com/amunmt/marian/blob/master/src/graph/node_operators_unary.h

 struct UnaryNodeOp : public NaryNodeOp

When building a computation graph, these nodes are collected in an ExpressionGraph object (defined in graph/expression_graph.h), which keeps these operations in a list.
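To see why keeping the nodes in a list is convenient, here is a hypothetical toy graph. ToyNode and ToyGraph are invented names, not Marian classes. Because a node can only use nodes created before it, the creation order is a valid order for the forward pass, and its reverse is a valid order for the backward pass.

```cpp
#include <cassert>
#include <cmath>
#include <functional>
#include <memory>
#include <vector>

// Hypothetical toy node: a value, an adjoint, and two deferred operations.
struct ToyNode {
  float val = 0.0f;
  float adj = 0.0f;
  std::function<void()> forward = [] {};
  std::function<void()> backward = [] {};
};

using NodePtr = std::shared_ptr<ToyNode>;

// Hypothetical toy graph: nodes are appended in creation order.
struct ToyGraph {
  std::vector<NodePtr> nodes;

  NodePtr add(NodePtr n) { nodes.push_back(n); return n; }

  // Root element: only holds a value.
  NodePtr param(float v) {
    auto n = std::make_shared<ToyNode>();
    n->val = v;
    return add(n);
  }

  NodePtr mul(NodePtr a, NodePtr b) {
    auto n = std::make_shared<ToyNode>();
    ToyNode* self = n.get();  // raw pointer avoids a shared_ptr cycle
    n->forward = [=] { self->val = a->val * b->val; };
    n->backward = [=] {
      a->adj += self->adj * b->val;
      b->adj += self->adj * a->val;
    };
    return add(n);
  }

  NodePtr sigmoid(NodePtr a) {
    auto n = std::make_shared<ToyNode>();
    ToyNode* self = n.get();
    n->forward = [=] { self->val = 1.0f / (1.0f + std::exp(-a->val)); };
    n->backward = [=] { a->adj += self->adj * self->val * (1.0f - self->val); };
    return add(n);
  }

  // Forward pass: creation order is a valid topological order.
  void forward() {
    for (auto& n : nodes) n->forward();
  }

  // Backward pass: seed the last node (the cost) with adjoint 1,
  // then run the backward operations in reverse creation order.
  void backward() {
    assert(!nodes.empty());
    for (auto& n : nodes) n->adj = 0.0f;
    nodes.back()->adj = 1.0f;
    for (auto it = nodes.rbegin(); it != nodes.rend(); ++it) (*it)->backward();
  }
};
```

For example, after building y = sigmoid(w * x) with g.sigmoid(g.mul(w, x)), one call to g.forward() and one to g.backward() leaves dy/dw in w->adj and dy/dx in x->adj.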

Example: Sigmoid (also known as logistic function)

The activation function sigmoid is a unary operator, so it is defined in graph/node_operators_unary.h:

 struct LogitNodeOp : public UnaryNodeOp {
  template <typename ...Args>
  LogitNodeOp(Args ...args)
  : UnaryNodeOp(args...) {  }

  NodeOps forwardOps() {
    return {
      NodeOp(Element(_1 = Sigma(_2),
                     val_,
                     child(0)->val()))
    };
  }

  NodeOps backwardOps() {
    return {
      NodeOp(Add(_1 * _2 * (1.0f - _2),
                 child(0)->grad(),
                 adj_, val_))
    };
  }

  const std::string type() {
    return "logit";
  }
 };

The definition includes the forward operator forwardOps (used for inference and as the first pass of training) and the backward operator backwardOps (used for training). Ultimately, these basic node operations have to be executed on the GPU. This requires the definition of CUDA kernels that are passed to the GPU along with the data they run on. The kernels and CUDA calls are in the directory kernels. Element, Sigma, and Add from the example above are defined in this GPU kernel code.
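Stripped of the GPU machinery, the two operations of LogitNodeOp amount to the following loops. This is a CPU sketch with invented function names; it assumes, as the use of Add suggests, that gradients are accumulated into the child's gradient rather than overwritten.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// forwardOps: val = Sigma(child value), element by element.
void sigmoidForward(std::vector<float>& val, const std::vector<float>& in) {
  assert(val.size() == in.size());
  for (std::size_t i = 0; i < val.size(); ++i)
    val[i] = 1.0f / (1.0f + std::exp(-in[i]));
}

// backwardOps: childGrad += adj * val * (1 - val), using
// d sigmoid(x)/dx = sigmoid(x) * (1 - sigmoid(x)). The "+=" lets several
// consumers of the same child add up their contributions.
void sigmoidBackward(std::vector<float>& childGrad,
                     const std::vector<float>& adj,
                     const std::vector<float>& val) {
  assert(childGrad.size() == adj.size() && adj.size() == val.size());
  for (std::size_t i = 0; i < childGrad.size(); ++i)
    childGrad[i] += adj[i] * val[i] * (1.0f - val[i]);
}
```

Note that the backward step only needs the node's own output, which is why backwardOps passes val_ rather than child(0)->val().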

Let us take a closer look at:

 NodeOp(Element(_1 = Sigma(_2),
                val_,
                child(0)->val()))

There is a lot of unusual C++ code in here.

The first thing to note is that forwardOps() and backwardOps() return lists of functions (NodeOps). These functions are created with the help of NodeOp, which wraps an operation into a lambda function:

 #define NodeOp(op) [ = ]() { op; }

This cryptic line of C++11 takes an operation op and turns it into a function (a lambda). Lambda functions are introduced with [], where the handling of variables from the surrounding scope may be specified; the [ = ] specification here means "capture any referenced variable by making a copy". The empty parentheses () state that the lambda takes no arguments, and the body of the function is the specified op. Some background on lambda functions is given in the side note below.
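A small, hypothetical example shows why capturing by copy matters here: the operations are created in one function and executed later, after that function's local variables are gone. The copied shared_ptr still points to the same data, so the deferred operations write where they should. The NodeOps alias and makeOps below are our own stand-ins, not Marian's definitions.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <vector>

// The macro from the text: a lambda with no parameters that captures
// every variable it uses by copy.
#define NodeOp(op) [=]() { op; }

// Hypothetical stand-in for Marian's NodeOps: a list of deferred operations.
using NodeOps = std::vector<std::function<void()>>;

// The operations are created here but run later, after this function has
// returned. [=] stores copies of `out` and `scale` inside each lambda;
// with [&] they would refer to parameters that no longer exist.
NodeOps makeOps(std::shared_ptr<float> out, float scale) {
  assert(out != nullptr);
  return {
    NodeOp(*out = 1.0f),
    NodeOp(*out *= scale)
  };
}
```

Running the returned operations in order on a shared float first sets it to 1 and then multiplies it by scale, even though makeOps has long returned by then.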

Element (defined in kernels/tensor_operators.h) is a generic template that creates element-wise operations: given a tensor, it applies a function to each value independently of the others. In the sigmoid example above, the sigmoid is applied to each element of the vector. Sigma is defined in the kernel code.

 template <class Functor, class T1, class T2>
 void Element(Functor functor, T1 out, T2 in) {
  [...]
  gElement<<<blocks, threads>>>(functor,
                                out->data(),
                                out->shape(),
                                in->data(),
                                in->shape(),
                                out->shape() != in->shape());
 }
 template <class Functor>
 __global__ void gElement(Functor functor, [...] ) {
  [...]
  out[index] = functor(out[index], in[inIndex]);
 }

_1 and _2 are placeholders that refer to the following arguments. _1 refers to the first following argument val_ and _2 refers to the second following argument child(0)->val(). This is enabled by the definition of _1 and _2 as placeholders in the thrust library that ships with CUDA. If you are not familiar with that, read the next section.

val_ stores the value computed by the node in the forward pass, and adj_ stores its adjoint, i.e., the gradient of the cost with respect to that value.
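A CPU analogue of Element can be written as an ordinary function template. This sketch (ElementHost and SigmaHost are invented names) applies the functor exactly like the kernel line out[index] = functor(out[index], in[inIndex]), but takes a plain C++ lambda instead of thrust placeholders and, unlike the real kernel, ignores broadcasting.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Scalar logistic function, standing in for the kernel-side Sigma.
inline float SigmaHost(float x) { return 1.0f / (1.0f + std::exp(-x)); }

// CPU analogue of Element: out[i] = functor(out[i], in[i]) for every i.
// This sketch requires matching shapes instead of broadcasting.
template <class Functor>
void ElementHost(Functor functor, std::vector<float>& out,
                 const std::vector<float>& in) {
  assert(out.size() == in.size());
  for (std::size_t i = 0; i < out.size(); ++i)
    out[i] = functor(out[i], in[i]);
}
```

Calling ElementHost([](float, float x) { return SigmaHost(x); }, val, in) plays the role of _1 = Sigma(_2), while a lambda such as [](float o, float x) { return o + x; } uses the old output value as well.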

Side Note: Placeholders and Lambda Functions

C++11 introduces lambda functions, which make it easy to define functions inline. For instance, a function that prints Hello World:

 std::function<void()> hello = []() { std::cout << "Hello World\n"; };
 hello();

To avoid the verbose syntax, Marian uses a preprocessor macro as syntactic sugar. Here, we use a slight variation (spot the difference!); more on that detail later.

 #define NodeOp(op) [&]() { op; }

 auto hello = NodeOp( std::cout << "hello world\n"; );
 hello();

This also prints hello world. Note that we are lazy and use auto instead of spelling out the function type.

Functions may also internally make use of variables.

  void PrintFloat(float a)
  {
    std::cout << "f(" << a << ")\n";
  }

  float x = 2;
  auto pa = NodeOp( PrintFloat(x) );
  pa();  // prints f(2)
  x = 5;
  pa();  // prints f(5)

Note that the lambda captures the variable x by reference, so each call to pa passes the current value of x to PrintFloat as its argument a. The generated function pa does not take arguments itself. The capture specification [&] defines this call-by-reference behavior. In Marian, [ = ] is used instead, so that pointers to objects are captured by copy; the copied pointers still refer to the same data.

We can also define a generic template for all kinds of variables we may want to print.

  template<class A>
  void PrintGeneric(A a)
  {
    std::cout << "f(" << a << ")\n";
  }

This does the same thing, but can also be used for other data types.

Finally, the big one. We now use a function template that uses a functor.

Functors

  using namespace thrust::placeholders;
  template<class F, class T1, class T2>
  void BinaryFunction(F f, T1& a, T2& b)  // by reference, so f can modify a
  {
     std::cout << "f(" << a << "," << b << ")\n";
     f(a, b);
  }

  float x = 2, y = 3;
  auto bf = NodeOp( BinaryFunction( _1 = _2, x, y) );
  std::cerr << "x = " << x << std::endl;     // prints x = 2
  bf();                                      // prints f(2,3)
  std::cerr << "x = " << x << std::endl;     // prints x = 3
  y = 5;
  bf();                                      // prints f(3,5)
  std::cerr << "x = " << x << std::endl;     // prints x = 5

Here, we pass to the function template BinaryFunction not only variables x and y, but also a functor _1 = _2. The placeholder _1 refers to x and _2 refers to y, the 1st and 2nd argument after the functor, respectively. These placeholders we use here are defined in thrust::placeholders but there are also variants of this in other libraries.
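How can _1 = _2 become a function at all? The sketch below is a toy re-implementation of the idea, not thrust's actual code; the namespace toy_placeholders and every type in it are invented. _1 and _2 are objects whose call operator selects the first or second argument, and overloaded operators combine them into new callable objects instead of computing a value immediately.

```cpp
#include <cassert>
#include <type_traits>

namespace toy_placeholders {

// Every expression type derives from this tag, so the overloaded
// operator* below only applies to placeholder expressions.
struct ExprTag {};

template <class T>
using IsExpr = std::is_base_of<ExprTag, T>;

// _1 = rhs: evaluate rhs, store the result in the first argument.
template <class R>
struct AssignFirst : ExprTag {
  R rhs;
  explicit AssignFirst(R r) : rhs(r) {}
  template <class A, class B>
  A& operator()(A& a, B& b) const { return a = rhs(a, b); }
};

// lhs * rhs: evaluate both sides and multiply.
template <class L, class R>
struct MulExpr : ExprTag {
  L lhs;
  R rhs;
  MulExpr(L l, R r) : lhs(l), rhs(r) {}
  template <class A, class B>
  auto operator()(A& a, B& b) const { return lhs(a, b) * rhs(a, b); }
};

// _1: selects the first argument (by reference, so it can be assigned).
struct First : ExprTag {
  template <class A, class B>
  A& operator()(A& a, B&) const { return a; }
  template <class R>
  AssignFirst<R> operator=(R rhs) const { return AssignFirst<R>(rhs); }
};

// _2: selects the second argument.
struct Second : ExprTag {
  template <class A, class B>
  B& operator()(A&, B& b) const { return b; }
};

template <class L, class R,
          class = std::enable_if_t<IsExpr<L>::value && IsExpr<R>::value>>
MulExpr<L, R> operator*(L l, R r) { return MulExpr<L, R>(l, r); }

const First _1{};
const Second _2{};

}  // namespace toy_placeholders

// Mirrors BinaryFunction from the text, without the printing.
template <class F, class T1, class T2>
void BinaryFunction(F f, T1& a, T2& b) { f(a, b); }
```

With float x = 2, y = 3, the call BinaryFunction(_1 = _2, x, y) sets x to 3, and BinaryFunction(_1 = _1 * _2, x, y) then sets x to 9. Real placeholder libraries support many more operators and constants, but the principle of building a callable object from an expression is the same.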

Example: Affine transform (the matrix multiplication plus bias addition of a neural network layer)

 struct AffineNodeOp : public NaryNodeOp {
  AffineNodeOp(const std::vector<Expr>& nodes)
    : NaryNodeOp(nodes, keywords::shape=newShape(nodes)) { }

  Shape newShape(const std::vector<Expr>& nodes) {
    Shape shape1 = nodes[0]->shape();
    Shape shape2 = nodes[1]->shape();
    UTIL_THROW_IF2(shape1[1] != shape2[0],
                   "matrix product requires dimensions to match");
    shape1.set(1, shape2[1]);
    return shape1;
  }

  NodeOps forwardOps() {
    return {
      NodeOp(
        Prod(getCublasHandle(),
             val_,
             child(0)->val(),
             child(1)->val(),
             false, false);
        Add(_1, val_, child(2)->val());
      )
    };
  }

  NodeOps backwardOps() {
    // D is the adjoint, the matrix of derivatives
    // df/dA += D*B.T
    // df/dB += A.T*D
    // beta set to 1.0 in gemm, C = dot(A,B) + beta * C
    // to sum gradients from different graph parts

    return {
      NodeOp(Prod(getCublasHandle(),
                  child(0)->grad(),
                  adj_,
                  child(1)->val(),
                  false, true, 1.0)),
      NodeOp(Prod(getCublasHandle(),
                  child(1)->grad(),
                  child(0)->val(),
                  adj_,
                  true, false, 1.0)),
      NodeOp(Add(_1, child(2)->grad(), adj_))
    };
  }

  const std::string type() {
    return "affine";
  }
 };
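The affine node's forward and backward computations can be checked on the CPU with plain loops. This is a sketch under our own conventions (a hypothetical row-major Mat type and invented function names). It mirrors the comments in backwardOps: dA += D*B^T, dB += A^T*D, and the bias gradient as the column sums of D, since the bias row was broadcast over all rows in the forward pass.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical row-major matrix, for this sketch only.
struct Mat {
  std::size_t rows, cols;
  std::vector<float> data;
  Mat(std::size_t r, std::size_t c, float v = 0.0f) : rows(r), cols(c), data(r * c, v) {}
  float& at(std::size_t i, std::size_t j) { return data[i * cols + j]; }
  float at(std::size_t i, std::size_t j) const { return data[i * cols + j]; }
};

// Forward: out = A * B + bias, with the 1 x n bias row broadcast over all rows.
Mat affineForward(const Mat& A, const Mat& B, const Mat& bias) {
  assert(A.cols == B.rows && bias.rows == 1 && bias.cols == B.cols);
  Mat out(A.rows, B.cols);
  for (std::size_t i = 0; i < A.rows; ++i)
    for (std::size_t j = 0; j < B.cols; ++j) {
      float sum = bias.at(0, j);
      for (std::size_t k = 0; k < A.cols; ++k) sum += A.at(i, k) * B.at(k, j);
      out.at(i, j) = sum;
    }
  return out;
}

// Backward, accumulating (+=) like gemm with beta = 1:
// dA += D * B^T, dB += A^T * D, dbias += column sums of D.
void affineBackward(const Mat& A, const Mat& B, const Mat& D,
                    Mat& dA, Mat& dB, Mat& dbias) {
  assert(D.rows == A.rows && D.cols == B.cols);
  for (std::size_t i = 0; i < A.rows; ++i)
    for (std::size_t k = 0; k < A.cols; ++k)
      for (std::size_t j = 0; j < B.cols; ++j) {
        dA.at(i, k) += D.at(i, j) * B.at(k, j);
        dB.at(k, j) += A.at(i, k) * D.at(i, j);
      }
  for (std::size_t i = 0; i < D.rows; ++i)
    for (std::size_t j = 0; j < D.cols; ++j) dbias.at(0, j) += D.at(i, j);
}
```

The GPU version replaces the triple loops with cuBLAS gemm calls (the transpose flags false/true select B^T or A^T), but the arithmetic is the same.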

Generic Layers

From these basic node operators, more complex node operators can be built. One useful class of such complex operators is neural network layers. These are defined in layers/generic.h.

A basic example is a classic feed-forward layer with an activation function:

 class DenseNew : public Layer 

Its parameters are a weight matrix and a bias vector.

  auto W = g->param(name_ + "_W", {in->shape()[1], outDim_},
                          keywords::init=inits::glorot_uniform);
  auto b = g->param(name_ + "_b", {1, outDim_},
                            keywords::init=inits::zeros);

The computation consists of the affine transform (matrix multiplication plus bias)

 out = affine(in, W, b);

and the activation function, for instance tanh:

 tanh(out)
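The following sketch spells out such a layer for a single input row. It assumes that inits::glorot_uniform follows the usual Glorot/Xavier uniform scheme, drawing each weight from U(-limit, limit) with limit = sqrt(6 / (fan_in + fan_out)); the function names are invented.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <random>
#include <vector>

// Assumed Glorot/Xavier uniform initialisation: each weight is drawn from
// U(-limit, limit) with limit = sqrt(6 / (fanIn + fanOut)).
std::vector<float> glorotUniform(std::size_t fanIn, std::size_t fanOut, unsigned seed) {
  float limit = std::sqrt(6.0f / static_cast<float>(fanIn + fanOut));
  std::mt19937 gen(seed);
  std::uniform_real_distribution<float> dist(-limit, limit);
  std::vector<float> w(fanIn * fanOut);
  for (float& v : w) v = dist(gen);
  return w;
}

// Dense layer for one input row: out = tanh(in * W + b), with W stored
// row-major as fanIn x fanOut and b of length fanOut.
std::vector<float> denseTanh(const std::vector<float>& in, const std::vector<float>& W,
                             const std::vector<float>& b) {
  std::size_t outDim = b.size();
  assert(W.size() == in.size() * outDim);
  std::vector<float> out(b);
  for (std::size_t j = 0; j < outDim; ++j) {
    for (std::size_t k = 0; k < in.size(); ++k) out[j] += in[k] * W[k * outDim + j];
    out[j] = std::tanh(out[j]);
  }
  return out;
}
```

Scaling the range by the layer's fan-in and fan-out keeps the initial activations from saturating tanh in wide layers, while the zero-initialised bias (inits::zeros in the code above) adds no initial offset.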

Attention layer

An NMT-specific example of a layer is the attention layer, defined in layers/attention.h.

  auto attReduce = attOps(va_, mappedContext_, mappedState);
  auto e = reshape(transpose(softmax(transpose(attReduce), softmaxMask_)),
                       {dimBatch, 1, srcWords, dimBeam});
  auto alignedSource = weighted_average(encState_->getContext(), e, axis=2);
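The core of this code is a softmax over source positions followed by a weighted average of the encoder context. The sketch below shows one common way to do this for a single sentence and decoder step, with a mask that gives padding positions zero weight. It assumes the attention scores (attReduce above) are already computed; the function names are invented and this is not Marian's implementation.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Masked softmax over source positions: positions with mask 0 (padding)
// get weight 0. Subtracting the largest score keeps exp() from overflowing.
std::vector<float> maskedSoftmax(const std::vector<float>& scores,
                                 const std::vector<float>& mask) {
  assert(scores.size() == mask.size());
  float maxScore = -std::numeric_limits<float>::infinity();
  for (std::size_t i = 0; i < scores.size(); ++i)
    if (mask[i] != 0.0f) maxScore = std::max(maxScore, scores[i]);
  std::vector<float> weights(scores.size(), 0.0f);
  float sum = 0.0f;
  for (std::size_t i = 0; i < scores.size(); ++i) {
    if (mask[i] != 0.0f) {
      weights[i] = std::exp(scores[i] - maxScore);
      sum += weights[i];
    }
  }
  assert(sum > 0.0f);  // needs at least one unmasked position
  for (float& w : weights) w /= sum;
  return weights;
}

// Weighted average of the context vectors (one per source word):
// the "aligned source" vector for this decoder step.
std::vector<float> weightedAverage(const std::vector<std::vector<float>>& context,
                                   const std::vector<float>& weights) {
  assert(!context.empty() && context.size() == weights.size());
  std::vector<float> out(context[0].size(), 0.0f);
  for (std::size_t i = 0; i < context.size(); ++i)
    for (std::size_t d = 0; d < out.size(); ++d) out[d] += weights[i] * context[i][d];
  return out;
}
```

In the real layer the same computation runs batched over sentences and beam entries, which is what the transpose and reshape calls above take care of.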

GRU unit

class GRU

https://github.com/amunmt/marian/blob/master/src/layers/rnn.h
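For reference, here is one common formulation of a single GRU step, with a reset gate r, an update gate z that weights the previous state, and a candidate state hTilde. Implementations differ in details such as where the reset gate is applied, so this is a sketch, not a transcription of Marian's GRU class; all names are invented.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Vec = std::vector<float>;
using Mat = std::vector<Vec>;  // Mat[i] is row i

Vec matVec(const Mat& M, const Vec& v) {
  Vec out(M.size(), 0.0f);
  for (std::size_t i = 0; i < M.size(); ++i) {
    assert(M[i].size() == v.size());
    for (std::size_t j = 0; j < v.size(); ++j) out[i] += M[i][j] * v[j];
  }
  return out;
}

float sigm(float v) { return 1.0f / (1.0f + std::exp(-v)); }

// Hypothetical parameter set: W* are n x nIn, U* are n x n, b* have length n.
struct GRUParams {
  Mat Wr, Ur, Wz, Uz, W, U;
  Vec br, bz, b;
};

// One time step:
//   r = sigm(Wr x + Ur h + br),  z = sigm(Wz x + Uz h + bz)
//   hTilde = tanh(W x + U (r * h) + b),  hNew = z * h + (1 - z) * hTilde
Vec gruStep(const GRUParams& p, const Vec& x, const Vec& h) {
  Vec rx = matVec(p.Wr, x), rh = matVec(p.Ur, h);
  Vec zx = matVec(p.Wz, x), zh = matVec(p.Uz, h);
  Vec z(h.size()), resetH(h.size());
  for (std::size_t i = 0; i < h.size(); ++i) {
    float r = sigm(rx[i] + rh[i] + p.br[i]);
    z[i] = sigm(zx[i] + zh[i] + p.bz[i]);
    resetH[i] = r * h[i];
  }
  Vec cx = matVec(p.W, x), ch = matVec(p.U, resetH);
  Vec hNew(h.size());
  for (std::size_t i = 0; i < h.size(); ++i) {
    float hTilde = std::tanh(cx[i] + ch[i] + p.b[i]);
    hNew[i] = z[i] * h[i] + (1.0f - z[i]) * hTilde;
  }
  return hNew;
}
```

When z is close to 1 the unit simply carries its previous state forward, which is what lets GRUs keep information over long sequences.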

Defining a computation graph

The computation graph of the sequence-to-sequence model with attention, the state-of-the-art NMT model at the time of writing, consists of an encoder and a decoder. It is defined in src/models/s2s.h.

Encoder

The encoder is defined in class EncoderS2S.

Source word embeddings:

  auto xEmb = Embedding(prefix_ + "_Wemb", dimSrcVoc, dimSrcEmb)(graph);

Forward ("left-to-right") RNN

  auto xFw = RNN<GRU>(graph, prefix_ + "_bi",
                          dimSrcEmb, dimEncState,
                          normalize=layerNorm,
                          dropout_prob=dropoutRnn)
                         (x);

Backward ("right-to-left") RNN

  auto xBw = RNN<GRU>(graph, prefix_ + "_bi_r",
                          dimSrcEmb, dimEncState,
                          normalize=layerNorm,
                          direction=dir::backward,
                          dropout_prob=dropoutRnn)
                         (x, mask=xMask);

Concatenation of forward and backward

  auto xContext = concatenate({xFw, xBw}, axis=1);
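To make the direction semantics concrete, here is a toy bidirectional encoder with a scalar tanh RNN standing in for RNN<GRU> (TinyRNN, run, and biContext are invented names). The backward RNN reads the sentence right to left, but its states are stored in source order, so that position t can be concatenated with the forward state at the same position.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical scalar RNN standing in for RNN<GRU>: h_t = tanh(w * x_t + u * h_prev).
struct TinyRNN {
  float w, u;
  float step(float x, float h) const { return std::tanh(w * x + u * h); }
};

// Runs the RNN left-to-right (forward) or right-to-left (backward) and
// returns one state per position, always indexed in source order.
std::vector<float> run(const TinyRNN& rnn, const std::vector<float>& xs, bool backward) {
  std::vector<float> states(xs.size());
  float h = 0.0f;
  for (std::size_t n = 0; n < xs.size(); ++n) {
    std::size_t t = backward ? xs.size() - 1 - n : n;
    h = rnn.step(xs[t], h);
    states[t] = h;
  }
  return states;
}

// Context for position t = [forward state ; backward state], so each
// position summarises both its left and its right context.
std::vector<std::vector<float>> biContext(const TinyRNN& fw, const TinyRNN& bw,
                                          const std::vector<float>& xs) {
  std::vector<float> f = run(fw, xs, false), b = run(bw, xs, true);
  assert(f.size() == b.size());
  std::vector<std::vector<float>> ctx;
  for (std::size_t t = 0; t < xs.size(); ++t) ctx.push_back({f[t], b[t]});
  return ctx;
}
```

With the input {1, 0, 0}, the last position's backward state is 0 (the backward RNN has only seen zeros there), while its forward state is positive because the forward RNN has already read the 1 at position 0.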

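Decoder

The decoder is defined in class DecoderS2S. It sets up a global attention mechanism over the encoder state and a first RNN layer of conditional GRU (CGRU) units that uses this attention: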

  attention_ = New<GlobalAttention>(prefix_,
                                    state->getEncoderState(),
                                    dimDecState,
                                    dropout_prob=dropoutRnn,
                                    normalize=layerNorm);

  rnnL1 = New<RNN<CGRU>>(graph, prefix_,
                         dimTrgEmb, dimDecState,
                         attention_,
                         dropout_prob=dropoutRnn,
                         normalize=layerNorm);

Decoding

Training

Page last modified on June 06, 2017, at 08:26 PM