TF之LSTM:利用LSTM算法对mnist手写数字图片数据集(TF函数自带)训练、评估(偶尔100%准确度,交叉熵验证)
生活随笔
收集整理的这篇文章主要介绍了
TF之LSTM:利用LSTM算法对mnist手写数字图片数据集(TF函数自带)训练、评估(偶尔100%准确度,交叉熵验证)
小编觉得挺不错的,现在分享给大家,帮大家做个参考.
TF之LSTM:利用LSTM算法对mnist手写数字图片数据集(TF函数自带)训练、评估(偶尔100%准确度,交叉熵验证)
目录
输出结果
设计思路
代码设计
输出结果
第 0 accuracy 0.125 第 20 accuracy 0.6484375 第 40 accuracy 0.78125 第 60 accuracy 0.9296875 第 80 accuracy 0.8671875 第 100 accuracy 0.90625 第 120 accuracy 0.8671875 第 140 accuracy 0.8671875 第 160 accuracy 0.8671875 第 180 accuracy 0.921875 第 200 accuracy 0.890625 第 220 accuracy 0.953125 第 240 accuracy 0.921875 第 260 accuracy 0.9296875 第 280 accuracy 0.9140625 第 300 accuracy 0.921875 第 320 accuracy 0.9609375 第 340 accuracy 0.953125 第 360 accuracy 0.984375 第 380 accuracy 0.921875 第 400 accuracy 0.9453125 第 420 accuracy 0.921875 第 440 accuracy 0.9296875 第 460 accuracy 0.96875 第 480 accuracy 0.984375 第 500 accuracy 0.96875 第 520 accuracy 0.953125 第 540 accuracy 0.96875 第 560 accuracy 0.953125 第 580 accuracy 0.9921875 第 600 accuracy 0.984375 第 620 accuracy 0.953125 第 640 accuracy 0.953125 第 660 accuracy 0.9921875 第 680 accuracy 0.96875 第 700 accuracy 0.9765625 第 720 accuracy 0.96875 第 740 accuracy 0.9921875 第 760 accuracy 0.984375 第 780 accuracy 0.953125
设计思路
代码设计
import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets('MNIST_data', one_hot=True)lr=0.001 training_iters=100000 batch_size=128 n_inputs=28 n_steps=28 n_hidden_units=128 n_classes=10 x=tf.placeholder(tf.float32, [None,n_steps,n_inputs]) y=tf.placeholder(tf.float32, [None,n_classes])weights ={'in':tf.Variable(tf.random_normal([n_inputs,n_hidden_units])),'out':tf.Variable(tf.random_normal([n_hidden_units,n_classes])),} biases ={'in':tf.Variable(tf.constant(0.1,shape=[n_hidden_units,])),'out':tf.Variable(tf.constant(0.1,shape=[n_classes,])),}def RNN(X,weights,biases): X=tf.reshape(X,[-1,n_inputs])X_in=tf.matmul(X,weights['in'])+biases['in'] X_in=tf.reshape(X_in,[-1,n_steps,n_hidden_units])lstm_cell=tf.nn.rnn_cell.BasicLSTMCell(n_hidden_units,forget_bias=1.0,state_is_tuple=True)__init__state=lstm_cell.zero_state(batch_size, dtype=tf.float32)outputs,states=tf.nn.dynamic_rnn(lstm_cell,X_in,initial_state=__init__state,time_major=False)outputs=tf.unpack(tf.transpose(outputs, [1,0,2]))results=tf.matmul(outputs[-1],weights['out'])+biases['out']return resultspred =RNN(x,weights,biases) cost =tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y)) train_op=tf.train.AdamOptimizer(lr).minimize(cost) correct_pred=tf.equal(tf.argmax(pred,1),tf.argmax(y,1)) accuracy=tf.reduce_mean(tf.cast(correct_pred,tf.float32)) <br> with tf.Session() as sess: sess.run(init)step=0while step*batch_size < training_iters: batch_xs,batch_ys=mnist.train.next_batch(batch_size)batch_xs=batch_xs.reshape([batch_size,n_steps,n_inputs])sess.run([train_op],feed_dict={x:batch_xs,y:batch_ys,})if step%20==0: print(sess.run(accuracy,feed_dict={x:batch_xs,y:batch_ys,}))step+=1
相关文章
TF之LSTM:利用LSTM算法对mnist手写数字图片数据集训练、评估(偶尔100%准确度)
总结
以上是生活随笔为你收集整理的TF之LSTM:利用LSTM算法对mnist手写数字图片数据集(TF函数自带)训练、评估(偶尔100%准确度,交叉熵验证)的全部内容,希望文章能够帮你解决所遇到的问题。
- 上一篇: AI:一个20年程序猿的学习资料大全—前
- 下一篇: 成功解决h5py\_init_.py:2