lstm_layer.hpp

#ifndef CAFFE_LSTM_LAYER_HPP_
#define CAFFE_LSTM_LAYER_HPP_

#include <string>
#include <utility>
#include <vector>

#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/layer.hpp"
#include "caffe/layers/recurrent_layer.hpp"
#include "caffe/net.hpp"
#include "caffe/proto/caffe.pb.h"

namespace caffe {

template <typename Dtype> class RecurrentLayer;
/**
 * @brief Processes sequential inputs using a "Long Short-Term Memory" (LSTM)
 *        [1] style recurrent neural network (RNN). Implemented by unrolling
 *        the LSTM computation through time.
 *
 * The specific architecture used in this implementation is as described in
 * "Learning to Execute" [2], reproduced below:
 *     i_t := \sigmoid[ W_{hi} * h_{t-1} + W_{xi} * x_t + b_i ]
 *     f_t := \sigmoid[ W_{hf} * h_{t-1} + W_{xf} * x_t + b_f ]
 *     o_t := \sigmoid[ W_{ho} * h_{t-1} + W_{xo} * x_t + b_o ]
 *     g_t := \tanh[ W_{hg} * h_{t-1} + W_{xg} * x_t + b_g ]
 *     c_t := (f_t .* c_{t-1}) + (i_t .* g_t)
 *     h_t := o_t .* \tanh[c_t]
 * In the implementation, the i, f, o, and g computations are performed as a
 * single inner product.
 *
 * Notably, this implementation lacks the "diagonal" gates, as used in the
 * LSTM architectures described by Alex Graves [3] and others.
 *
 * [1] Hochreiter, Sepp, and Schmidhuber, Jürgen. "Long short-term memory."
 *     Neural Computation 9, no. 8 (1997): 1735-1780.
 *
 * [2] Zaremba, Wojciech, and Sutskever, Ilya. "Learning to execute."
 *     arXiv preprint arXiv:1410.4615 (2014).
 *
 * [3] Graves, Alex. "Generating sequences with recurrent neural networks."
 *     arXiv preprint arXiv:1308.0850 (2013).
 */
template <typename Dtype>
class LSTMLayer : public RecurrentLayer<Dtype> {
 public:
  explicit LSTMLayer(const LayerParameter& param)
      : RecurrentLayer<Dtype>(param) {}

  virtual inline const char* type() const { return "LSTM"; }

 protected:
  virtual void FillUnrolledNet(NetParameter* net_param) const;
  virtual void RecurrentInputBlobNames(vector<string>* names) const;
  virtual void RecurrentOutputBlobNames(vector<string>* names) const;
  virtual void RecurrentInputShapes(vector<BlobShape>* shapes) const;
  virtual void OutputBlobNames(vector<string>* names) const;
};
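
// Illustrative sketch only, not part of LSTMLayer and not used anywhere in
// Caffe: a scalar reference of the single-timestep recurrence documented
// above, for a single hidden unit. The function and parameter names are
// hypothetical; the per-gate weights and biases are passed in the order
// [i, f, o, g], and std::exp / std::tanh are assumed to be available through
// <cmath> (pulled in transitively by caffe/common.hpp).
template <typename Dtype>
inline Dtype lstm_reference_sigmoid(Dtype x) {
  return Dtype(1) / (Dtype(1) + std::exp(-x));
}

template <typename Dtype>
inline void lstm_reference_step(
    Dtype x_t, Dtype h_prev, Dtype c_prev,
    const Dtype W_h[4], const Dtype W_x[4], const Dtype b[4],
    Dtype* c_t, Dtype* h_t) {
  // Gate activations, per the equations above (i, f, o use \sigmoid; g uses \tanh).
  const Dtype i_t =
      lstm_reference_sigmoid(W_h[0] * h_prev + W_x[0] * x_t + b[0]);
  const Dtype f_t =
      lstm_reference_sigmoid(W_h[1] * h_prev + W_x[1] * x_t + b[1]);
  const Dtype o_t =
      lstm_reference_sigmoid(W_h[2] * h_prev + W_x[2] * x_t + b[2]);
  const Dtype g_t = std::tanh(W_h[3] * h_prev + W_x[3] * x_t + b[3]);
  *c_t = f_t * c_prev + i_t * g_t;  // c_t := (f_t .* c_{t-1}) + (i_t .* g_t)
  *h_t = o_t * std::tanh(*c_t);     // h_t := o_t .* \tanh[c_t]
}
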
/**
 * @brief A helper for LSTMLayer: computes a single timestep of the
 *        non-linearity of the LSTM, producing the updated cell and hidden
 *        states.
 */
template <typename Dtype>
class LSTMUnitLayer : public Layer<Dtype> {
 public:
  explicit LSTMUnitLayer(const LayerParameter& param)
      : Layer<Dtype>(param) {}
  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);

  virtual inline const char* type() const { return "LSTMUnit"; }
  virtual inline int ExactNumBottomBlobs() const { return 3; }
  virtual inline int ExactNumTopBlobs() const { return 2; }

  virtual inline bool AllowForceBackward(const int bottom_index) const {
    // Can't propagate to sequence continuation indicators.
    return bottom_index != 2;
  }

 protected:
  /**
   * @param bottom input Blob vector (length 3)
   *   -# @f$ (1 \times N \times D) @f$
   *      the previous timestep cell state @f$ c_{t-1} @f$
   *   -# @f$ (1 \times N \times 4D) @f$
   *      the "gate inputs" @f$ [i_t', f_t', o_t', g_t'] @f$
   *   -# @f$ (1 \times N) @f$
   *      the sequence continuation indicators @f$ \delta_t @f$
   * @param top output Blob vector (length 2)
   *   -# @f$ (1 \times N \times D) @f$
   *      the updated cell state @f$ c_t @f$, computed as:
   *          i_t := \sigmoid[i_t']
   *          f_t := \sigmoid[f_t']
   *          o_t := \sigmoid[o_t']
   *          g_t := \tanh[g_t']
   *          c_t := cont_t * (f_t .* c_{t-1}) + (i_t .* g_t)
   *   -# @f$ (1 \times N \times D) @f$
   *      the updated hidden state @f$ h_t @f$, computed as:
   *          h_t := o_t .* \tanh[c_t]
   */
  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);
  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);

  /**
   * @brief Computes the error gradient w.r.t. the LSTMUnit inputs.
   *
   * @param top output Blob vector (length 2), providing the error gradient
   *        with respect to the outputs
   *   -# @f$ (1 \times N \times D) @f$:
   *      containing error gradients @f$ \frac{\partial E}{\partial c_t} @f$
   *      with respect to the updated cell state @f$ c_t @f$
   *   -# @f$ (1 \times N \times D) @f$:
   *      containing error gradients @f$ \frac{\partial E}{\partial h_t} @f$
   *      with respect to the updated hidden state @f$ h_t @f$
   * @param propagate_down see Layer::Backward.
   * @param bottom input Blob vector (length 3), into which the error gradients
   *        with respect to the LSTMUnit inputs @f$ c_{t-1} @f$ and the gate
   *        inputs are computed. Computation of the error gradients w.r.t.
   *        the sequence continuation indicators is not implemented.
  115. * -# @f$ (1 \times N \times D) @f$
  116. * the error gradient w.r.t. the previous timestep cell state
  117. * @f$ c_{t-1} @f$
  118. * -# @f$ (1 \times N \times 4D) @f$
  119. * the error gradient w.r.t. the "gate inputs"
  120. * @f$ [
  121. * \frac{\partial E}{\partial i_t}
  122. * \frac{\partial E}{\partial f_t}
  123. * \frac{\partial E}{\partial o_t}
  124. * \frac{\partial E}{\partial g_t}
  125. * ] @f$
  126. * -# @f$ (1 \times 1 \times N) @f$
  127. * the gradient w.r.t. the sequence continuation indicators
  128. * @f$ \delta_t @f$ is currently not computed.
  129. */
  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
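
  // For reference only -- an illustrative derivation, not a description of
  // this header's API: differentiating the forward definitions documented
  // above, and taking cont_t = 1, the per-element gradients are
  //     dE/dc_{t-1} = dc * f_t
  //     dE/di_t'    = dc * g_t * i_t * (1 - i_t)
  //     dE/df_t'    = dc * c_{t-1} * f_t * (1 - f_t)
  //     dE/do_t'    = dE/dh_t * \tanh[c_t] * o_t * (1 - o_t)
  //     dE/dg_t'    = dc * i_t * (1 - g_t^2)
  // where dc := dE/dc_t + dE/dh_t * o_t * (1 - \tanh[c_t]^2) accumulates the
  // gradient that reaches c_t both directly and through h_t.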

  /// @brief The hidden and output dimension.
  int hidden_dim_;
  Blob<Dtype> X_acts_;
};
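
// Illustrative usage sketch only, written in the style of Caffe's unit tests;
// none of the variable names below exist in this header, and the shapes
// follow the Forward_cpu documentation (N sequences, hidden dimension D):
//
//   LayerParameter param;
//   LSTMUnitLayer<float> unit(param);
//   Blob<float> c_prev(vector<int>{1, N, D});     // bottom[0]: c_{t-1}
//   Blob<float> gates(vector<int>{1, N, 4 * D});  // bottom[1]: [i', f', o', g']
//   Blob<float> cont(vector<int>{1, N});          // bottom[2]: \delta_t
//   Blob<float> c_t, h_t;
//   vector<Blob<float>*> bottom{&c_prev, &gates, &cont};
//   vector<Blob<float>*> top{&c_t, &h_t};
//   unit.SetUp(bottom, top);    // Reshape() sizes the two top blobs
//   unit.Forward(bottom, top);  // fills c_t and h_t as documented above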

}  // namespace caffe

#endif  // CAFFE_LSTM_LAYER_HPP_