Tensorflow LSTM原理

原创文章,转载请注明: 转载自慢慢的回味

本文链接地址: Tensorflow LSTM原理

LSTM是RNN的一种cell实现。那么什么是RNN呢?RNN是一种特殊的神经网络结构,它是根据“人的认知是基于过往的经验和记忆”这一观点而提出的,它使网络具有对前面内容的“记忆”功能。在隐藏层中,一个序列当前的输出与前面的输出也有关,具体的表现形式为:网络会对前面的信息进行记忆并应用于当前输出的计算中,即隐藏层之间的节点不再是无连接的,而是有连接的,并且隐藏层的输入不仅包括输入层的输出,还包括上一时刻隐藏层的输出,所以前面所有的结果都可以影响后面的输出。因此它被广泛运用于文本生成、机器翻译、机器写小说、语音识别等。
RNN整个网络架构如下所示:

可以观看李宏毅的视频《25.循环神经网络RNN(I)》了解更多的RNN原理

LSTM原理

本文只说明LSTM的原理和一个基于Tensorflow的简单测试,测试代码以下图为案例:

如下图所示,LSTM单元接受3个输入:当前输入x_t,上一Cell状态c_t-1,上一次输出h_t-1。
1 上一次输出h_{t-1}和当前输入x_t按y轴(axis=1)拼接在一起,和权值矩阵W相乘,然后将结果按y轴均分成4组:i,j(z_2),f,o。
2 上一次状态c_{t-1}和遗忘开关f_t相乘,完成遗忘的处理;再加上受输入开关i_t控制的新输入(i_t*j_t),完成状态的更新,即 $c_t = c_{t-1} \cdot f_t + i_t \cdot j_t$。
3 更新后的状态c_t(经过激活函数后)和输出开关o_t相乘,得到输出h_t。

LSTM测试

如下面的测试代码:
h_{t-1}为1X1矩阵,x_t为1X3矩阵,拼接在一起为1X4矩阵,和4X4的权值矩阵相乘,输出1X4矩阵,再均分后i,j,f,o都是1X1矩阵。
注意:
为了和案例一致,激活函数为linear,即输出自身。案例上的输入bias由h_t-1提供,且强制设成1以满足案例。

from tensorflow.python.ops.rnn_cell_impl import LSTMCell
import tensorflow as tf;
from tensorflow.python.ops import array_ops
from tensorflow.python.framework.ops import numpy_text
 
# Single-unit LSTM cell with an identity ("linear") activation and no forget
# bias, so each step can be checked by hand against the worked example.
cell = LSTMCell(1, state_is_tuple=False, forget_bias=0.0, activation="linear")
cell.build(input_shape=[1, 3])

# 4x4 kernel; the article splits the product into the four gates i, j, f, o.
# NOTE(review): presumably the first three rows weight x_t and the last row
# weights h_{t-1} (used as the bias here) -- confirm against LSTMCell.
kernel = [
    [0., 1., 0., 0.],
    [100., 0., 100., 0.],
    [0., 0., 0., 100.],
    [-10., 0., 10., -10.],
]
cell._kernel.assign(kernel)

# Five time steps of three features each, split into one (1, 3) tensor per step.
inputs = tf.constant([
    [3.0, 1.0, 0.0],
    [4.0, 1.0, 0.0],
    [2.0, 0.0, 0.0],
    [1.0, 0.0, 1.0],
    [3.0, -1.0, 0.0],
])
inputArray = array_ops.split(value=inputs, num_or_size_splits=5, axis=0)

# The state is the concatenation [c, h]; start with c = 0 and h = 1.
states = tf.constant([[0.0, 1.0]])
c_texts = []
h_texts = []
for step_input in inputArray:
    _, new_state = cell.call(step_input, states)
    print("====output====")
    print(new_state)
    c, h = array_ops.split(value=new_state, num_or_size_splits=2, axis=1)
    c_texts.append(numpy_text(c))
    h_texts.append(numpy_text(h))
    # Keep the new cell state but force h back to 1 so that it keeps acting
    # as the constant bias input of the worked example.
    states = array_ops.concat([c, tf.ones_like(h)], 1)

# Every value is followed by a single space, including the last one.
cStr = "c values: " + "".join(text + " " for text in c_texts)
hStr = "h values: " + "".join(text + " " for text in h_texts)
print(cStr)
print(hStr)

输出如下,结果h和案例图上的y是一致的。

====output====
tf.Tensor([[3.0000000e+00 1.3619362e-04]], shape=(1, 2), dtype=float32)
====output====
tf.Tensor([[7.000000e+00 3.177851e-04]], shape=(1, 2), dtype=float32)
====output====
tf.Tensor([[6.9997725e+00 3.1777477e-04]], shape=(1, 2), dtype=float32)
====output====
tf.Tensor([[6.9995 6.9995]], shape=(1, 2), dtype=float32)
====output====
tf.Tensor([[0. 0.]], shape=(1, 2), dtype=float32)
c values: [[3.]] [[7.]] [[6.9997725]] [[6.9995]] [[0.]] 
h values: [[0.00013619]] [[0.00031779]] [[0.00031777]] [[6.9995]] [[0.]]
RNN测试

当然,上面的程序是我们手动循环LSTM Cell来模拟的,我们也可以调用框架的RNN来运行:
需要注意的是,使用框架的RNN时无法像上面那样在每一步手动把h_prev重置为1来充当bias,此时bias位置上的值就是上一次的真实输出h_prev,所以结果和上面的测试不一样。

from tensorflow.python.ops.rnn_cell_impl import LSTMCell
import tensorflow as tf;
from tensorflow.python.ops.rnn import static_rnn, static_bidirectional_rnn
 
# Same single-unit, linear-activation LSTM cell and kernel as the manual test.
cell = LSTMCell(1, state_is_tuple=False, forget_bias=0.0, activation="linear")
cell.build(input_shape=[1, 3])
kernel = [
    [0., 1., 0., 0.],
    [100., 0., 100., 0.],
    [0., 0., 0., 100.],
    [-10., 0., 10., -10.],
]
cell._kernel.assign(kernel)

# One (1, 3) tensor per time step, as a Python list of step inputs.
steps = [
    [3.0, 1.0, 0.0],
    [4.0, 1.0, 0.0],
    [2.0, 0.0, 0.0],
    [1.0, 0.0, 1.0],
    [3.0, -1.0, 0.0],
]
inputs = [tf.constant(step, shape=[1, 3]) for step in steps]

# Initial state is the concatenation [c, h] = [0, 1].
states = tf.constant([[0.0, 1.0]])

# Unidirectional run: the framework feeds each step's real h into the next step.
outputs, state = static_rnn(cell, inputs, initial_state=states)
print(outputs)
print(state)

# Bidirectional run; the same cell object is passed for both directions.
outputs, output_state_fw, output_state_bw = static_bidirectional_rnn(
    cell, cell, inputs, initial_state_fw=states, initial_state_bw=states)
print(outputs)
print(output_state_fw)
print(output_state_bw)
LSTM的C++ API测试

程序的流程和上面的Python LSTM测试类似,不过只测试一个input [3.0, 1.0, 0.0]。注意这里的状态经过了tanh(例如cs=0.995055=tanh(3)),而上面的Python测试使用linear激活(c=3),所以两者的数值并不相同。
输出结果为:
i:1
cs:0.995055
f:1
o:4.53979e-05
ci:0.995055
co:0.75951
h:3.44801e-05

#include <iostream>
#include <tensorflow/c/c_api.h>
#include <tensorflow/c/c_test_util.h>
 
#include <algorithm>
#include <cstddef>
#include <iterator>
#include <memory>
#include <vector>
#include <string.h>
 
using namespace std;
 
// Deallocator callback handed to TF_NewTensor: releases a float buffer that
// was allocated with new[]. The length and user-argument parameters are unused.
static void FloatDeallocator(void* data, size_t, void* arg) {
	float* buffer = static_cast<float*>(data);
	delete[] buffer;
}
 
int main() {
	cout << "!!!Hello World!!!" << endl; // prints !!!Hello World!!!
	cout << "Hello from TensorFlow C library version" << TF_Version() << endl;
 
	TF_Status* s = TF_NewStatus();
	TF_Graph* graph = TF_NewGraph();
 
	int64_t xDims[] = { 1, 3 };
	float* xVals = new float[3] { 3.0, 1.0, 0.0 };
	TF_Tensor* x = TF_NewTensor(TF_FLOAT, xDims, 2, xVals, sizeof(float) * 3,
			&FloatDeallocator, nullptr);
 
	int64_t csDims[] = { 1, 1 };
	float* csVals = new float[1] { 0.0 };
	TF_Tensor* cs = TF_NewTensor(TF_FLOAT, csDims, 2, csVals, sizeof(float) * 1,
			&FloatDeallocator, nullptr);
	int64_t hDims[] = { 1, 1 };
	float* hVals = new float[1] { 1.0 };
	TF_Tensor* h = TF_NewTensor(TF_FLOAT, hDims, 2, hVals, sizeof(float) * 1,
			&FloatDeallocator, nullptr);
 
	int64_t wDims[] = { 4, 4 };
	float* wVals = new float[16] { 0., 1., 0., 0., 100., 0., 100., 0., 0., 0.,
			0., 100., -10., 0., 10., -10. };
	TF_Tensor* w = TF_NewTensor(TF_FLOAT, wDims, 2, wVals, sizeof(float) * 16,
			&FloatDeallocator, nullptr);
	int64_t bDims[] = { 4 };
	float* bVals = new float[16] { 0., 0., 0., 0. };
	TF_Tensor* b = TF_NewTensor(TF_FLOAT, bDims, 1, bVals, sizeof(float) * 4,
			&FloatDeallocator, nullptr);
 
	int64_t wciDims[] = { 1 };
	float* wciVals = new float[1] { 0.5 };
	TF_Tensor* wci = TF_NewTensor(TF_FLOAT, wciDims, 1, wciVals,
			sizeof(float) * 1, &FloatDeallocator, nullptr);
	int64_t wcfDims[] = { 1 };
	float* wcfVals = new float[1] { 0.5 };
	TF_Tensor* wcf = TF_NewTensor(TF_FLOAT, wcfDims, 1, wcfVals,
			sizeof(float) * 1, &FloatDeallocator, nullptr);
	int64_t wcoDims[] = { 1 };
	float* wcoVals = new float[1] { 0.5 };
	TF_Tensor* wco = TF_NewTensor(TF_FLOAT, wcoDims, 1, wcoVals,
			sizeof(float) * 1, &FloatDeallocator, nullptr);
 
	TF_OperationDescription* desc = TF_NewOperation(graph, "LSTMBlockCell",
			"lstmCell");
	TF_Operation* xInput = Const(x, graph, s, "x");
	TF_Operation* csInput = Const(cs, graph, s, "cs_prev");
	TF_Operation* hInput = Const(h, graph, s, "h_prev");
	TF_Operation* wInput = Const(w, graph, s, "w");
	TF_Operation* wciInput = Const(wci, graph, s, "wci");
	TF_Operation* wcfInput = Const(wcf, graph, s, "wcf");
	TF_Operation* wcoInput = Const(wco, graph, s, "wco");
	TF_Operation* bInput = Const(b, graph, s, "b");
	TF_AddInput(desc, { xInput, 0 });
	TF_AddInput(desc, { csInput, 0 });
	TF_AddInput(desc, { hInput, 0 });
	TF_AddInput(desc, { wInput, 0 });
	TF_AddInput(desc, { wciInput, 0 });
	TF_AddInput(desc, { wcfInput, 0 });
	TF_AddInput(desc, { wcoInput, 0 });
	TF_AddInput(desc, { bInput, 0 });
	TF_SetAttrType(desc, "T", TF_FLOAT);
	TF_SetAttrFloat(desc, "forget_bias", 0.0);
	TF_SetAttrFloat(desc, "cell_clip", 10.0);
	TF_SetAttrBool(desc, "use_peephole", false);
	TF_Operation *op = TF_FinishOperation(desc, s);
	cout << TF_Message(s);
	cout << TF_Message(s);
 
	TF_SessionOptions* opts = TF_NewSessionOptions();
	TF_Session* sess = TF_NewSession(graph, opts, s);
	TF_DeleteSessionOptions(opts);
 
	TF_Output feeds[] = { };
	TF_Output fetches[] = { TF_Output { op, 0 }, TF_Output { op, 1 },
			TF_Output { op, 2 }, TF_Output { op, 3 }, TF_Output { op, 4 },
			TF_Output { op, 5 }, TF_Output { op, 6 } };
 
	const char* handle = nullptr;
	TF_SessionPRunSetup(sess, feeds, TF_ARRAYSIZE(feeds), fetches,
			TF_ARRAYSIZE(fetches), NULL, 0, &handle, s);
	cout << TF_Message(s);
 
	// Feed A and fetch A + 2.
	TF_Output feeds1[] = { };
	TF_Output fetches1[] = { TF_Output { op, 0 }, TF_Output { op, 1 },
			TF_Output { op, 2 }, TF_Output { op, 3 }, TF_Output { op, 4 },
			TF_Output { op, 5 }, TF_Output { op, 6 } };
	TF_Tensor* feedValues1[] = { };
	TF_Tensor* fetchValues1[7];
	TF_SessionPRun(sess, handle, feeds1, feedValues1, 0, fetches1, fetchValues1,
			7, NULL, 0, s);
	cout << TF_Message(s);
	for (int k = 0; k < 7; k++) {
		float* data = static_cast<float*>(TF_TensorData(fetchValues1[k]));
		size_t size = TF_TensorByteSize(fetchValues1[k]) / sizeof(float);
		for (size_t i = 0; i < size; i++) {
			cout << data[i] << endl;
		}
	}
	// Clean up.
	TF_DeletePRunHandle(handle);
	TF_DeleteSession(sess, s);
	cout << TF_Message(s);
	TF_DeleteGraph(graph);
	TF_DeleteStatus(s);
	return 0;
}

本作品采用知识共享署名 4.0 国际许可协议进行许可。

发表回复